1use std::fs::{File, OpenOptions};
61use std::io::{Read, Write};
62use std::path::{Path, PathBuf};
63use std::time::{Duration, Instant};
64
65use sochdb_core::SochDBError;
66
/// Failure modes when acquiring a database lock.
#[derive(Debug)]
pub enum LockError {
    /// Another holder currently owns the lock.
    DatabaseLocked {
        /// PID recorded by the current holder, when it could be read.
        holder_pid: Option<u32>,
        /// Path of the lock file that was contended.
        lock_path: PathBuf,
    },
    /// The lock was still held when the configured wait time ran out.
    Timeout {
        /// How long the acquirer waited.
        elapsed: Duration,
        /// The configured maximum wait.
        timeout: Duration,
    },
    /// A lock left behind by a process that is no longer running.
    StaleLock {
        /// PID recorded in the stale lock file.
        stale_pid: u32,
    },
    /// Underlying filesystem or locking-syscall failure.
    Io(std::io::Error),
}
96
97impl std::fmt::Display for LockError {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 match self {
100 LockError::DatabaseLocked { holder_pid, lock_path } => {
101 if let Some(pid) = holder_pid {
102 write!(f, "Database is locked by process {} (lock file: {})",
103 pid, lock_path.display())
104 } else {
105 write!(f, "Database is locked (lock file: {})", lock_path.display())
106 }
107 }
108 LockError::Timeout { elapsed, timeout } => {
109 write!(f, "Lock acquisition timed out after {:?} (timeout: {:?})",
110 elapsed, timeout)
111 }
112 LockError::StaleLock { stale_pid } => {
113 write!(f, "Stale lock detected from crashed process {}", stale_pid)
114 }
115 LockError::Io(e) => write!(f, "Lock I/O error: {}", e),
116 }
117 }
118}
119
120impl std::error::Error for LockError {
121 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
122 match self {
123 LockError::Io(e) => Some(e),
124 _ => None,
125 }
126 }
127}
128
129impl From<std::io::Error> for LockError {
130 fn from(e: std::io::Error) -> Self {
131 LockError::Io(e)
132 }
133}
134
135impl From<LockError> for SochDBError {
136 fn from(e: LockError) -> Self {
137 match e {
138 LockError::DatabaseLocked { holder_pid, lock_path } => {
139 SochDBError::LockError(format!(
140 "Database locked by PID {:?} (lock: {})",
141 holder_pid, lock_path.display()
142 ))
143 }
144 LockError::Timeout { elapsed, timeout } => {
145 SochDBError::LockError(format!(
146 "Lock timeout after {:?} (max: {:?})", elapsed, timeout
147 ))
148 }
149 LockError::StaleLock { stale_pid } => {
150 SochDBError::LockError(format!(
151 "Stale lock from crashed process {}", stale_pid
152 ))
153 }
154 LockError::Io(e) => SochDBError::Io(e),
155 }
156 }
157}
158
/// Tuning knobs for lock acquisition.
#[derive(Debug, Clone)]
pub struct LockConfig {
    /// Maximum time to keep retrying; `None` fails on the first conflict.
    pub timeout: Option<Duration>,
    /// Pause between retries while waiting for the lock.
    pub retry_interval: Duration,
    /// When set, a lock file whose recorded PID is not a running process is
    /// removed and acquisition is retried.
    pub detect_stale_locks: bool,
    /// File name of the lock file, joined onto the database directory.
    pub lock_file_name: String,
}
175
176impl Default for LockConfig {
177 fn default() -> Self {
178 Self {
179 timeout: Some(Duration::from_secs(5)),
180 retry_interval: Duration::from_millis(100),
181 detect_stale_locks: true,
182 lock_file_name: ".lock".to_string(),
183 }
184 }
185}
186
187impl LockConfig {
188 pub fn no_wait() -> Self {
190 Self {
191 timeout: None,
192 ..Default::default()
193 }
194 }
195
196 pub fn with_timeout(timeout: Duration) -> Self {
198 Self {
199 timeout: Some(timeout),
200 ..Default::default()
201 }
202 }
203}
204
/// Exclusive, cross-process lock on a database directory.
///
/// The lock is taken on a lock file inside the directory and released when
/// this value is dropped.
pub struct DatabaseLock {
    /// Open handle to the lock file; the OS-level lock is tied to this handle.
    lock_file: File,
    /// Full path of the lock file.
    path: PathBuf,
    /// PID of this process, as written into the lock file.
    our_pid: u32,
}
232
233impl DatabaseLock {
234 pub fn acquire<P: AsRef<Path>>(db_path: P) -> std::result::Result<Self, LockError> {
259 Self::acquire_with_config(db_path, &LockConfig::default())
260 }
261
262 pub fn acquire_no_wait<P: AsRef<Path>>(db_path: P) -> std::result::Result<Self, LockError> {
267 Self::acquire_with_config(db_path, &LockConfig::no_wait())
268 }
269
270 pub fn acquire_with_timeout<P: AsRef<Path>>(
279 db_path: P,
280 timeout: Duration
281 ) -> std::result::Result<Self, LockError> {
282 Self::acquire_with_config(db_path, &LockConfig::with_timeout(timeout))
283 }
284
285 pub fn acquire_with_config<P: AsRef<Path>>(
287 db_path: P,
288 config: &LockConfig,
289 ) -> std::result::Result<Self, LockError> {
290 let db_path = db_path.as_ref();
291 let lock_path = db_path.join(&config.lock_file_name);
292
293 if !db_path.exists() {
295 std::fs::create_dir_all(db_path)?;
296 }
297
298 let deadline = config.timeout.map(|t| Instant::now() + t);
299 let our_pid = std::process::id();
300
301 loop {
302 let file = OpenOptions::new()
304 .create(true)
305 .read(true)
306 .write(true)
307 .open(&lock_path)?;
308
309 match Self::try_flock(&file, false) {
311 Ok(()) => {
312 Self::write_pid(&file, our_pid)?;
314
315 return Ok(Self {
316 lock_file: file,
317 path: lock_path,
318 our_pid,
319 });
320 }
321 Err(LockError::DatabaseLocked { .. }) => {
322 let mut should_retry = false;
326 if config.detect_stale_locks {
327 if let Some(holder_pid) = Self::read_pid(&file) {
328 if !Self::process_exists(holder_pid) {
329 drop(file);
332
333 if std::fs::remove_file(&lock_path).is_ok() {
335 should_retry = true;
336 }
337 }
338 }
339 }
340
341 if should_retry {
342 continue; }
344
345 if let Some(deadline) = deadline {
347 if Instant::now() >= deadline {
348 return Err(LockError::Timeout {
349 elapsed: config.timeout.unwrap_or_default(),
350 timeout: config.timeout.unwrap_or_default(),
351 });
352 }
353
354 std::thread::sleep(config.retry_interval);
356 continue;
357 } else {
358 return Err(LockError::DatabaseLocked {
361 holder_pid: None,
362 lock_path
363 });
364 }
365 }
366 Err(e) => return Err(e),
367 }
368 }
369 }
370
371 pub fn path(&self) -> &Path {
373 &self.path
374 }
375
376 pub fn pid(&self) -> u32 {
378 self.our_pid
379 }
380
381 pub fn get_lock_holder<P: AsRef<Path>>(db_path: P) -> Option<u32> {
385 let lock_path = db_path.as_ref().join(".lock");
386 let file = File::open(&lock_path).ok()?;
387 Self::read_pid(&file)
388 }
389
390 fn write_pid(file: &File, pid: u32) -> std::result::Result<(), LockError> {
392 use std::io::Seek;
393 let mut file = file;
394 file.seek(std::io::SeekFrom::Start(0))?;
395 file.set_len(0)?;
396 writeln!(file, "{}", pid)?;
397 file.sync_all()?;
398 Ok(())
399 }
400
401 fn read_pid(file: &File) -> Option<u32> {
403 use std::io::Seek;
404 let mut file = file;
405 let _ = file.seek(std::io::SeekFrom::Start(0));
406 let mut contents = String::new();
407 file.read_to_string(&mut contents).ok()?;
408 contents.trim().parse().ok()
409 }
410
411 #[cfg(unix)]
413 fn process_exists(pid: u32) -> bool {
414 let result = unsafe { libc::kill(pid as libc::pid_t, 0) };
417 if result == 0 {
418 true
419 } else {
420 let errno = std::io::Error::last_os_error().raw_os_error();
422 errno != Some(libc::ESRCH)
423 }
424 }
425
426 #[cfg(windows)]
427 fn process_exists(pid: u32) -> bool {
428 unsafe {
429 let handle = windows_sys::Win32::System::Threading::OpenProcess(
430 windows_sys::Win32::System::Threading::PROCESS_QUERY_LIMITED_INFORMATION,
431 0,
432 pid,
433 );
434 if handle == 0 || handle == -1 {
435 false
436 } else {
437 windows_sys::Win32::Foundation::CloseHandle(handle);
438 true
439 }
440 }
441 }
442
443 #[cfg(not(any(unix, windows)))]
444 fn process_exists(_pid: u32) -> bool {
445 true
447 }
448
449 #[cfg(unix)]
451 fn try_flock(file: &File, blocking: bool) -> std::result::Result<(), LockError> {
452 use std::os::unix::io::AsRawFd;
453
454 let fd = file.as_raw_fd();
455 let operation = if blocking {
456 libc::LOCK_EX
457 } else {
458 libc::LOCK_EX | libc::LOCK_NB
459 };
460
461 let result = unsafe { libc::flock(fd, operation) };
462
463 if result == 0 {
464 Ok(())
465 } else {
466 let err = std::io::Error::last_os_error();
467 if err.raw_os_error() == Some(libc::EWOULDBLOCK) {
468 Err(LockError::DatabaseLocked {
469 holder_pid: None,
470 lock_path: PathBuf::new(),
471 })
472 } else {
473 Err(LockError::Io(err))
474 }
475 }
476 }
477
478 #[cfg(windows)]
479 fn try_flock(file: &File, blocking: bool) -> std::result::Result<(), LockError> {
480 use std::os::windows::io::AsRawHandle;
481
482 let handle = file.as_raw_handle() as windows_sys::Win32::Foundation::HANDLE;
483
484 let flags = windows_sys::Win32::Storage::FileSystem::LOCKFILE_EXCLUSIVE_LOCK
485 | if blocking { 0 } else { windows_sys::Win32::Storage::FileSystem::LOCKFILE_FAIL_IMMEDIATELY };
486
487 let mut overlapped: windows_sys::Win32::System::IO::OVERLAPPED = unsafe { std::mem::zeroed() };
488
489 let result = unsafe {
490 windows_sys::Win32::Storage::FileSystem::LockFileEx(
491 handle,
492 flags,
493 0,
494 1,
495 0,
496 &mut overlapped,
497 )
498 };
499
500 if result != 0 {
501 Ok(())
502 } else {
503 let err = std::io::Error::last_os_error();
504 if err.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_LOCK_VIOLATION as i32) {
505 Err(LockError::DatabaseLocked {
506 holder_pid: None,
507 lock_path: PathBuf::new(),
508 })
509 } else {
510 Err(LockError::Io(err))
511 }
512 }
513 }
514
515 #[cfg(not(any(unix, windows)))]
516 fn try_flock(_file: &File, _blocking: bool) -> std::result::Result<(), LockError> {
517 Ok(())
520 }
521
522 #[cfg(unix)]
524 fn release(&self) {
525 use std::os::unix::io::AsRawFd;
526 let fd = self.lock_file.as_raw_fd();
527 unsafe { libc::flock(fd, libc::LOCK_UN) };
528 }
529
530 #[cfg(windows)]
531 fn release(&self) {
532 use std::os::windows::io::AsRawHandle;
533 let handle = self.lock_file.as_raw_handle() as windows_sys::Win32::Foundation::HANDLE;
534 let mut overlapped: windows_sys::Win32::System::IO::OVERLAPPED = unsafe { std::mem::zeroed() };
535 unsafe {
536 windows_sys::Win32::Storage::FileSystem::UnlockFileEx(
537 handle,
538 0,
539 1,
540 0,
541 &mut overlapped,
542 );
543 }
544 }
545
546 #[cfg(not(any(unix, windows)))]
547 fn release(&self) {
548 }
550}
551
552impl Drop for DatabaseLock {
553 fn drop(&mut self) {
554 self.release();
555 let _ = std::fs::remove_file(&self.path);
558 }
559}
560
/// Fixed-layout (`#[repr(C)]`) record of reader/writer lock counters.
///
/// NOTE(review): nothing visible in this module reads or writes this struct.
/// Presumably it describes a layout shared with other code; confirm before
/// changing field order or widths.
#[derive(Debug, Clone, Copy, Default)]
#[repr(C)]
pub struct RwLockState {
    /// Count of readers (per the field name).
    pub reader_count: u32,
    /// Writer-intent flag/count (per the field name).
    pub writer_intent: u32,
    /// Active-writer flag/count (per the field name).
    pub writer_active: u32,
    /// Explicit padding; keeps the record at four `u32`s.
    pub _padding: u32,
}
586
/// Kind of access a [`RwDatabaseLock`] grants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionMode {
    /// Shared lock: may coexist with other read-only holders.
    ReadOnly,
    /// Exclusive lock: excludes every other holder.
    ReadWrite,
}
595
596pub struct RwDatabaseLock {
603 lock_file: File,
605 path: PathBuf,
607 mode: ConnectionMode,
609 our_pid: u32,
611}
612
613impl RwDatabaseLock {
614 pub fn acquire_shared<P: AsRef<Path>>(db_path: P) -> std::result::Result<Self, LockError> {
619 Self::acquire_with_mode(db_path, ConnectionMode::ReadOnly, &LockConfig::default())
620 }
621
622 pub fn acquire_exclusive<P: AsRef<Path>>(db_path: P) -> std::result::Result<Self, LockError> {
627 Self::acquire_with_mode(db_path, ConnectionMode::ReadWrite, &LockConfig::default())
628 }
629
630 pub fn acquire_with_mode<P: AsRef<Path>>(
632 db_path: P,
633 mode: ConnectionMode,
634 config: &LockConfig,
635 ) -> std::result::Result<Self, LockError> {
636 let db_path = db_path.as_ref();
637 let lock_path = db_path.join(&config.lock_file_name);
638
639 if !db_path.exists() {
640 std::fs::create_dir_all(db_path)?;
641 }
642
643 let file = OpenOptions::new()
644 .create(true)
645 .read(true)
646 .write(true)
647 .open(&lock_path)?;
648
649 let our_pid = std::process::id();
650 let deadline = config.timeout.map(|t| Instant::now() + t);
651
652 loop {
653 match mode {
654 ConnectionMode::ReadOnly => {
655 if Self::try_shared_lock(&file)? {
657 return Ok(Self {
658 lock_file: file,
659 path: lock_path,
660 mode,
661 our_pid,
662 });
663 }
664 }
665 ConnectionMode::ReadWrite => {
666 if Self::try_exclusive_lock(&file)? {
668 return Ok(Self {
669 lock_file: file,
670 path: lock_path,
671 mode,
672 our_pid,
673 });
674 }
675 }
676 }
677
678 if let Some(deadline) = deadline {
680 if Instant::now() >= deadline {
681 return Err(LockError::Timeout {
682 elapsed: config.timeout.unwrap_or_default(),
683 timeout: config.timeout.unwrap_or_default(),
684 });
685 }
686 std::thread::sleep(config.retry_interval);
687 } else {
688 return Err(LockError::DatabaseLocked {
689 holder_pid: None,
690 lock_path,
691 });
692 }
693 }
694 }
695
696 pub fn mode(&self) -> ConnectionMode {
698 self.mode
699 }
700
701 pub fn is_readonly(&self) -> bool {
703 self.mode == ConnectionMode::ReadOnly
704 }
705
706 #[cfg(unix)]
707 fn try_shared_lock(file: &File) -> std::result::Result<bool, LockError> {
708 use std::os::unix::io::AsRawFd;
709 let fd = file.as_raw_fd();
710 let result = unsafe { libc::flock(fd, libc::LOCK_SH | libc::LOCK_NB) };
711 if result == 0 {
712 Ok(true)
713 } else {
714 let err = std::io::Error::last_os_error();
715 if err.raw_os_error() == Some(libc::EWOULDBLOCK) {
716 Ok(false)
717 } else {
718 Err(LockError::Io(err))
719 }
720 }
721 }
722
723 #[cfg(unix)]
724 fn try_exclusive_lock(file: &File) -> std::result::Result<bool, LockError> {
725 use std::os::unix::io::AsRawFd;
726 let fd = file.as_raw_fd();
727 let result = unsafe { libc::flock(fd, libc::LOCK_EX | libc::LOCK_NB) };
728 if result == 0 {
729 Ok(true)
730 } else {
731 let err = std::io::Error::last_os_error();
732 if err.raw_os_error() == Some(libc::EWOULDBLOCK) {
733 Ok(false)
734 } else {
735 Err(LockError::Io(err))
736 }
737 }
738 }
739
740 #[cfg(windows)]
741 fn try_shared_lock(file: &File) -> std::result::Result<bool, LockError> {
742 use std::os::windows::io::AsRawHandle;
743 let handle = file.as_raw_handle() as windows_sys::Win32::Foundation::HANDLE;
744 let mut overlapped: windows_sys::Win32::System::IO::OVERLAPPED = unsafe { std::mem::zeroed() };
745
746 let result = unsafe {
747 windows_sys::Win32::Storage::FileSystem::LockFileEx(
748 handle,
749 windows_sys::Win32::Storage::FileSystem::LOCKFILE_FAIL_IMMEDIATELY,
750 0, 1, 0,
751 &mut overlapped,
752 )
753 };
754
755 if result != 0 {
756 Ok(true)
757 } else {
758 let err = std::io::Error::last_os_error();
759 if err.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_LOCK_VIOLATION as i32) {
760 Ok(false)
761 } else {
762 Err(LockError::Io(err))
763 }
764 }
765 }
766
767 #[cfg(windows)]
768 fn try_exclusive_lock(file: &File) -> std::result::Result<bool, LockError> {
769 use std::os::windows::io::AsRawHandle;
770 let handle = file.as_raw_handle() as windows_sys::Win32::Foundation::HANDLE;
771 let mut overlapped: windows_sys::Win32::System::IO::OVERLAPPED = unsafe { std::mem::zeroed() };
772
773 let result = unsafe {
774 windows_sys::Win32::Storage::FileSystem::LockFileEx(
775 handle,
776 windows_sys::Win32::Storage::FileSystem::LOCKFILE_EXCLUSIVE_LOCK
777 | windows_sys::Win32::Storage::FileSystem::LOCKFILE_FAIL_IMMEDIATELY,
778 0, 1, 0,
779 &mut overlapped,
780 )
781 };
782
783 if result != 0 {
784 Ok(true)
785 } else {
786 let err = std::io::Error::last_os_error();
787 if err.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_LOCK_VIOLATION as i32) {
788 Ok(false)
789 } else {
790 Err(LockError::Io(err))
791 }
792 }
793 }
794
795 #[cfg(not(any(unix, windows)))]
796 fn try_shared_lock(_file: &File) -> std::result::Result<bool, LockError> {
797 Ok(true)
798 }
799
800 #[cfg(not(any(unix, windows)))]
801 fn try_exclusive_lock(_file: &File) -> std::result::Result<bool, LockError> {
802 Ok(true)
803 }
804
805 #[cfg(unix)]
806 fn release(&self) {
807 use std::os::unix::io::AsRawFd;
808 let fd = self.lock_file.as_raw_fd();
809 unsafe { libc::flock(fd, libc::LOCK_UN) };
810 }
811
812 #[cfg(windows)]
813 fn release(&self) {
814 use std::os::windows::io::AsRawHandle;
815 let handle = self.lock_file.as_raw_handle() as windows_sys::Win32::Foundation::HANDLE;
816 let mut overlapped: windows_sys::Win32::System::IO::OVERLAPPED = unsafe { std::mem::zeroed() };
817 unsafe {
818 windows_sys::Win32::Storage::FileSystem::UnlockFileEx(handle, 0, 1, 0, &mut overlapped);
819 }
820 }
821
822 #[cfg(not(any(unix, windows)))]
823 fn release(&self) {}
824}
825
826impl Drop for RwDatabaseLock {
827 fn drop(&mut self) {
828 self.release();
829 }
830}
831
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use tempfile::TempDir;

    /// A second exclusive acquirer is refused while the first holds the lock,
    /// and succeeds once the first is dropped.
    #[test]
    fn test_exclusive_lock_basic() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path();

        let first = DatabaseLock::acquire(path);
        assert!(first.is_ok());

        let contender = DatabaseLock::acquire_no_wait(path);
        assert!(matches!(contender, Err(LockError::DatabaseLocked { .. })));

        drop(first);
        let again = DatabaseLock::acquire(path);
        assert!(again.is_ok());
    }

    /// `acquire` waits (default timeout) until a holder on another thread lets go.
    #[test]
    fn test_acquire_default_timeout() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().to_path_buf();

        let held = DatabaseLock::acquire(&path).unwrap();
        let releaser = thread::spawn(move || {
            thread::sleep(Duration::from_millis(200));
            drop(held);
        });

        let began = Instant::now();
        let outcome = DatabaseLock::acquire(&path);
        let waited = began.elapsed();

        assert!(outcome.is_ok(), "acquire() should succeed after lock is released");
        assert!(waited >= Duration::from_millis(100), "should have waited for lock");
        assert!(waited < Duration::from_secs(2), "should not wait too long");

        releaser.join().unwrap();
    }

    /// A bounded wait against a held lock ends in `Timeout` close to the bound.
    #[test]
    fn test_lock_with_timeout() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().to_path_buf();

        let _held = DatabaseLock::acquire(&path).unwrap();

        let began = Instant::now();
        let outcome = DatabaseLock::acquire_with_timeout(&path, Duration::from_millis(100));
        let waited = began.elapsed();

        assert!(matches!(outcome, Err(LockError::Timeout { .. })));
        assert!(waited >= Duration::from_millis(100));
        assert!(waited < Duration::from_millis(500));
    }

    /// The holder's PID is exposed both on the lock and via the lock file.
    #[test]
    fn test_lock_pid_recorded() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path();

        let lock = DatabaseLock::acquire(path).unwrap();
        let me = std::process::id();

        assert_eq!(lock.pid(), me);
        assert_eq!(DatabaseLock::get_lock_holder(path), Some(me));
    }

    /// Two shared (read-only) locks can be held at the same time.
    #[test]
    fn test_shared_lock_multiple_readers() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path();

        let reader_a = RwDatabaseLock::acquire_shared(path);
        let reader_b = RwDatabaseLock::acquire_shared(path);

        assert!(reader_a.is_ok());
        assert!(reader_b.is_ok());
    }

    /// A held exclusive lock makes a no-wait shared request fail immediately.
    #[test]
    fn test_exclusive_blocks_shared() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path();

        let _writer = RwDatabaseLock::acquire_exclusive(path).unwrap();

        let reader = RwDatabaseLock::acquire_with_mode(
            path,
            ConnectionMode::ReadOnly,
            &LockConfig::no_wait(),
        );

        assert!(matches!(reader, Err(LockError::DatabaseLocked { .. })));
    }
}