1use std::fs::{File, OpenOptions};
61use std::io::{Read, Write};
62use std::path::{Path, PathBuf};
63use std::time::{Duration, Instant};
64
65use sochdb_core::SochDBError;
66
/// Errors that can occur while acquiring or holding a database lock.
#[derive(Debug)]
pub enum LockError {
    /// The lock is currently held elsewhere and the caller did not wait for it.
    DatabaseLocked {
        /// PID of the process recorded as holding the lock, if known.
        holder_pid: Option<u32>,
        /// Path of the lock file that could not be locked.
        lock_path: PathBuf,
    },
    /// The lock could not be obtained before the configured timeout expired.
    Timeout {
        /// How long the caller waited, as reported by the acquiring code.
        elapsed: Duration,
        /// The configured maximum wait.
        timeout: Duration,
    },
    /// A lock attributed to a process that appears to have crashed.
    ///
    /// NOTE(review): nothing in this module constructs this variant.
    StaleLock {
        /// PID recorded by the presumed-dead holder.
        stale_pid: u32,
    },
    /// An underlying I/O failure (creating, opening, locking or writing the lock file).
    Io(std::io::Error),
}

impl std::fmt::Display for LockError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::DatabaseLocked {
                holder_pid: Some(pid),
                lock_path,
            } => write!(
                f,
                "Database is locked by process {} (lock file: {})",
                pid,
                lock_path.display()
            ),
            Self::DatabaseLocked {
                holder_pid: None,
                lock_path,
            } => write!(f, "Database is locked (lock file: {})", lock_path.display()),
            Self::Timeout { elapsed, timeout } => write!(
                f,
                "Lock acquisition timed out after {:?} (timeout: {:?})",
                elapsed, timeout
            ),
            Self::StaleLock { stale_pid } => {
                write!(f, "Stale lock detected from crashed process {}", stale_pid)
            }
            Self::Io(err) => write!(f, "Lock I/O error: {}", err),
        }
    }
}

impl std::error::Error for LockError {
    /// Only I/O failures wrap an underlying error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(inner) => Some(inner),
            Self::DatabaseLocked { .. } | Self::Timeout { .. } | Self::StaleLock { .. } => None,
        }
    }
}

impl From<std::io::Error> for LockError {
    fn from(err: std::io::Error) -> Self {
        Self::Io(err)
    }
}
144
145impl From<LockError> for SochDBError {
146 fn from(e: LockError) -> Self {
147 match e {
148 LockError::DatabaseLocked {
149 holder_pid,
150 lock_path,
151 } => SochDBError::LockError(format!(
152 "Database locked by PID {:?} (lock: {})",
153 holder_pid,
154 lock_path.display()
155 )),
156 LockError::Timeout { elapsed, timeout } => SochDBError::LockError(format!(
157 "Lock timeout after {:?} (max: {:?})",
158 elapsed, timeout
159 )),
160 LockError::StaleLock { stale_pid } => {
161 SochDBError::LockError(format!("Stale lock from crashed process {}", stale_pid))
162 }
163 LockError::Io(e) => SochDBError::Io(e),
164 }
165 }
166}
167
/// Settings controlling how a database lock is acquired.
#[derive(Debug, Clone)]
pub struct LockConfig {
    /// Maximum time to keep retrying. `None` means a single attempt with no waiting.
    pub timeout: Option<Duration>,
    /// Pause between successive acquisition attempts while waiting.
    pub retry_interval: Duration,
    /// Whether the PID recorded in the lock file is checked against running processes.
    pub detect_stale_locks: bool,
    /// File name of the lock file created inside the database directory.
    pub lock_file_name: String,
}

impl Default for LockConfig {
    /// Wait up to 5 seconds, retry every 100 ms, check recorded PIDs, and use `.lock`.
    fn default() -> Self {
        let timeout = Some(Duration::from_secs(5));
        let retry_interval = Duration::from_millis(100);
        Self {
            timeout,
            retry_interval,
            detect_stale_locks: true,
            lock_file_name: String::from(".lock"),
        }
    }
}

impl LockConfig {
    /// Default settings, but give up immediately if the lock is already held.
    pub fn no_wait() -> Self {
        Self::default_with_timeout(None)
    }

    /// Default settings with a custom maximum wait.
    pub fn with_timeout(timeout: Duration) -> Self {
        Self::default_with_timeout(Some(timeout))
    }

    /// The default configuration with only `timeout` replaced.
    fn default_with_timeout(timeout: Option<Duration>) -> Self {
        Self {
            timeout,
            ..Self::default()
        }
    }
}
213
214pub struct DatabaseLock {
234 lock_file: File,
236 path: PathBuf,
238 our_pid: u32,
240}
241
242impl DatabaseLock {
243 pub fn acquire<P: AsRef<Path>>(db_path: P) -> std::result::Result<Self, LockError> {
268 Self::acquire_with_config(db_path, &LockConfig::default())
269 }
270
271 pub fn acquire_no_wait<P: AsRef<Path>>(db_path: P) -> std::result::Result<Self, LockError> {
276 Self::acquire_with_config(db_path, &LockConfig::no_wait())
277 }
278
279 pub fn acquire_with_timeout<P: AsRef<Path>>(
288 db_path: P,
289 timeout: Duration,
290 ) -> std::result::Result<Self, LockError> {
291 Self::acquire_with_config(db_path, &LockConfig::with_timeout(timeout))
292 }
293
294 pub fn acquire_with_config<P: AsRef<Path>>(
296 db_path: P,
297 config: &LockConfig,
298 ) -> std::result::Result<Self, LockError> {
299 let db_path = db_path.as_ref();
300 let lock_path = db_path.join(&config.lock_file_name);
301
302 if !db_path.exists() {
304 std::fs::create_dir_all(db_path)?;
305 }
306
307 let deadline = config.timeout.map(|t| Instant::now() + t);
308 let our_pid = std::process::id();
309
310 loop {
311 let file = OpenOptions::new()
313 .create(true)
314 .read(true)
315 .write(true)
316 .open(&lock_path)?;
317
318 match Self::try_flock(&file, false) {
320 Ok(()) => {
321 Self::write_pid(&file, our_pid)?;
323
324 return Ok(Self {
325 lock_file: file,
326 path: lock_path,
327 our_pid,
328 });
329 }
330 Err(LockError::DatabaseLocked { .. }) => {
331 let mut should_retry = false;
335 if config.detect_stale_locks {
336 if let Some(holder_pid) = Self::read_pid(&file) {
337 if !Self::process_exists(holder_pid) {
338 drop(file);
341
342 if std::fs::remove_file(&lock_path).is_ok() {
344 should_retry = true;
345 }
346 }
347 }
348 }
349
350 if should_retry {
351 continue; }
353
354 if let Some(deadline) = deadline {
356 if Instant::now() >= deadline {
357 return Err(LockError::Timeout {
358 elapsed: config.timeout.unwrap_or_default(),
359 timeout: config.timeout.unwrap_or_default(),
360 });
361 }
362
363 std::thread::sleep(config.retry_interval);
365 continue;
366 } else {
367 return Err(LockError::DatabaseLocked {
370 holder_pid: None,
371 lock_path,
372 });
373 }
374 }
375 Err(e) => return Err(e),
376 }
377 }
378 }
379
380 pub fn path(&self) -> &Path {
382 &self.path
383 }
384
385 pub fn pid(&self) -> u32 {
387 self.our_pid
388 }
389
390 pub fn get_lock_holder<P: AsRef<Path>>(db_path: P) -> Option<u32> {
394 let lock_path = db_path.as_ref().join(".lock");
395 let file = File::open(&lock_path).ok()?;
396 Self::read_pid(&file)
397 }
398
399 fn write_pid(file: &File, pid: u32) -> std::result::Result<(), LockError> {
401 use std::io::Seek;
402 let mut file = file;
403 file.seek(std::io::SeekFrom::Start(0))?;
404 file.set_len(0)?;
405 writeln!(file, "{}", pid)?;
406 file.sync_all()?;
407 Ok(())
408 }
409
410 fn read_pid(file: &File) -> Option<u32> {
412 use std::io::Seek;
413 let mut file = file;
414 let _ = file.seek(std::io::SeekFrom::Start(0));
415 let mut contents = String::new();
416 file.read_to_string(&mut contents).ok()?;
417 contents.trim().parse().ok()
418 }
419
420 #[cfg(unix)]
422 fn process_exists(pid: u32) -> bool {
423 let result = unsafe { libc::kill(pid as libc::pid_t, 0) };
426 if result == 0 {
427 true
428 } else {
429 let errno = std::io::Error::last_os_error().raw_os_error();
431 errno != Some(libc::ESRCH)
432 }
433 }
434
435 #[cfg(windows)]
436 fn process_exists(pid: u32) -> bool {
437 unsafe {
438 let handle = windows_sys::Win32::System::Threading::OpenProcess(
439 windows_sys::Win32::System::Threading::PROCESS_QUERY_LIMITED_INFORMATION,
440 0,
441 pid,
442 );
443 if handle == 0 || handle == -1 {
444 false
445 } else {
446 windows_sys::Win32::Foundation::CloseHandle(handle);
447 true
448 }
449 }
450 }
451
452 #[cfg(not(any(unix, windows)))]
453 fn process_exists(_pid: u32) -> bool {
454 true
456 }
457
458 #[cfg(unix)]
460 fn try_flock(file: &File, blocking: bool) -> std::result::Result<(), LockError> {
461 use std::os::unix::io::AsRawFd;
462
463 let fd = file.as_raw_fd();
464 let operation = if blocking {
465 libc::LOCK_EX
466 } else {
467 libc::LOCK_EX | libc::LOCK_NB
468 };
469
470 let result = unsafe { libc::flock(fd, operation) };
471
472 if result == 0 {
473 Ok(())
474 } else {
475 let err = std::io::Error::last_os_error();
476 if err.raw_os_error() == Some(libc::EWOULDBLOCK) {
477 Err(LockError::DatabaseLocked {
478 holder_pid: None,
479 lock_path: PathBuf::new(),
480 })
481 } else {
482 Err(LockError::Io(err))
483 }
484 }
485 }
486
487 #[cfg(windows)]
488 fn try_flock(file: &File, blocking: bool) -> std::result::Result<(), LockError> {
489 use std::os::windows::io::AsRawHandle;
490
491 let handle = file.as_raw_handle() as windows_sys::Win32::Foundation::HANDLE;
492
493 let flags = windows_sys::Win32::Storage::FileSystem::LOCKFILE_EXCLUSIVE_LOCK
494 | if blocking {
495 0
496 } else {
497 windows_sys::Win32::Storage::FileSystem::LOCKFILE_FAIL_IMMEDIATELY
498 };
499
500 let mut overlapped: windows_sys::Win32::System::IO::OVERLAPPED =
501 unsafe { std::mem::zeroed() };
502
503 let result = unsafe {
504 windows_sys::Win32::Storage::FileSystem::LockFileEx(
505 handle,
506 flags,
507 0,
508 1,
509 0,
510 &mut overlapped,
511 )
512 };
513
514 if result != 0 {
515 Ok(())
516 } else {
517 let err = std::io::Error::last_os_error();
518 if err.raw_os_error()
519 == Some(windows_sys::Win32::Foundation::ERROR_LOCK_VIOLATION as i32)
520 {
521 Err(LockError::DatabaseLocked {
522 holder_pid: None,
523 lock_path: PathBuf::new(),
524 })
525 } else {
526 Err(LockError::Io(err))
527 }
528 }
529 }
530
531 #[cfg(not(any(unix, windows)))]
532 fn try_flock(_file: &File, _blocking: bool) -> std::result::Result<(), LockError> {
533 Ok(())
536 }
537
538 #[cfg(unix)]
540 fn release(&self) {
541 use std::os::unix::io::AsRawFd;
542 let fd = self.lock_file.as_raw_fd();
543 unsafe { libc::flock(fd, libc::LOCK_UN) };
544 }
545
546 #[cfg(windows)]
547 fn release(&self) {
548 use std::os::windows::io::AsRawHandle;
549 let handle = self.lock_file.as_raw_handle() as windows_sys::Win32::Foundation::HANDLE;
550 let mut overlapped: windows_sys::Win32::System::IO::OVERLAPPED =
551 unsafe { std::mem::zeroed() };
552 unsafe {
553 windows_sys::Win32::Storage::FileSystem::UnlockFileEx(handle, 0, 1, 0, &mut overlapped);
554 }
555 }
556
557 #[cfg(not(any(unix, windows)))]
558 fn release(&self) {
559 }
561}
562
563impl Drop for DatabaseLock {
564 fn drop(&mut self) {
565 self.release();
566 let _ = std::fs::remove_file(&self.path);
569 }
570}
571
/// Fixed-layout (`#[repr(C)]`) block of four `u32` lock counters/flags, 16 bytes in total.
///
/// NOTE(review): not referenced elsewhere in this module. Presumably intended for a
/// shared-memory reader/writer protocol; confirm with its users.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct RwLockState {
    /// Presumably the number of active readers.
    pub reader_count: u32,
    /// Presumably non-zero while a writer is waiting for access.
    pub writer_intent: u32,
    /// Presumably non-zero while a writer holds access.
    pub writer_active: u32,
    /// Explicit padding word.
    pub _padding: u32,
}
597
/// Access mode requested from [`RwDatabaseLock`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionMode {
    /// Shared lock: several read-only holders may coexist.
    ReadOnly,
    /// Exclusive lock: a single read-write holder.
    ReadWrite,
}
606
607pub struct RwDatabaseLock {
614 lock_file: File,
616 path: PathBuf,
618 mode: ConnectionMode,
620 our_pid: u32,
622}
623
624impl RwDatabaseLock {
625 pub fn acquire_shared<P: AsRef<Path>>(db_path: P) -> std::result::Result<Self, LockError> {
630 Self::acquire_with_mode(db_path, ConnectionMode::ReadOnly, &LockConfig::default())
631 }
632
633 pub fn acquire_exclusive<P: AsRef<Path>>(db_path: P) -> std::result::Result<Self, LockError> {
638 Self::acquire_with_mode(db_path, ConnectionMode::ReadWrite, &LockConfig::default())
639 }
640
641 pub fn acquire_with_mode<P: AsRef<Path>>(
643 db_path: P,
644 mode: ConnectionMode,
645 config: &LockConfig,
646 ) -> std::result::Result<Self, LockError> {
647 let db_path = db_path.as_ref();
648 let lock_path = db_path.join(&config.lock_file_name);
649
650 if !db_path.exists() {
651 std::fs::create_dir_all(db_path)?;
652 }
653
654 let file = OpenOptions::new()
655 .create(true)
656 .read(true)
657 .write(true)
658 .open(&lock_path)?;
659
660 let our_pid = std::process::id();
661 let deadline = config.timeout.map(|t| Instant::now() + t);
662
663 loop {
664 match mode {
665 ConnectionMode::ReadOnly => {
666 if Self::try_shared_lock(&file)? {
668 return Ok(Self {
669 lock_file: file,
670 path: lock_path,
671 mode,
672 our_pid,
673 });
674 }
675 }
676 ConnectionMode::ReadWrite => {
677 if Self::try_exclusive_lock(&file)? {
679 return Ok(Self {
680 lock_file: file,
681 path: lock_path,
682 mode,
683 our_pid,
684 });
685 }
686 }
687 }
688
689 if let Some(deadline) = deadline {
691 if Instant::now() >= deadline {
692 return Err(LockError::Timeout {
693 elapsed: config.timeout.unwrap_or_default(),
694 timeout: config.timeout.unwrap_or_default(),
695 });
696 }
697 std::thread::sleep(config.retry_interval);
698 } else {
699 return Err(LockError::DatabaseLocked {
700 holder_pid: None,
701 lock_path,
702 });
703 }
704 }
705 }
706
707 pub fn mode(&self) -> ConnectionMode {
709 self.mode
710 }
711
712 pub fn is_readonly(&self) -> bool {
714 self.mode == ConnectionMode::ReadOnly
715 }
716
717 #[cfg(unix)]
718 fn try_shared_lock(file: &File) -> std::result::Result<bool, LockError> {
719 use std::os::unix::io::AsRawFd;
720 let fd = file.as_raw_fd();
721 let result = unsafe { libc::flock(fd, libc::LOCK_SH | libc::LOCK_NB) };
722 if result == 0 {
723 Ok(true)
724 } else {
725 let err = std::io::Error::last_os_error();
726 if err.raw_os_error() == Some(libc::EWOULDBLOCK) {
727 Ok(false)
728 } else {
729 Err(LockError::Io(err))
730 }
731 }
732 }
733
734 #[cfg(unix)]
735 fn try_exclusive_lock(file: &File) -> std::result::Result<bool, LockError> {
736 use std::os::unix::io::AsRawFd;
737 let fd = file.as_raw_fd();
738 let result = unsafe { libc::flock(fd, libc::LOCK_EX | libc::LOCK_NB) };
739 if result == 0 {
740 Ok(true)
741 } else {
742 let err = std::io::Error::last_os_error();
743 if err.raw_os_error() == Some(libc::EWOULDBLOCK) {
744 Ok(false)
745 } else {
746 Err(LockError::Io(err))
747 }
748 }
749 }
750
751 #[cfg(windows)]
752 fn try_shared_lock(file: &File) -> std::result::Result<bool, LockError> {
753 use std::os::windows::io::AsRawHandle;
754 let handle = file.as_raw_handle() as windows_sys::Win32::Foundation::HANDLE;
755 let mut overlapped: windows_sys::Win32::System::IO::OVERLAPPED =
756 unsafe { std::mem::zeroed() };
757
758 let result = unsafe {
759 windows_sys::Win32::Storage::FileSystem::LockFileEx(
760 handle,
761 windows_sys::Win32::Storage::FileSystem::LOCKFILE_FAIL_IMMEDIATELY,
762 0,
763 1,
764 0,
765 &mut overlapped,
766 )
767 };
768
769 if result != 0 {
770 Ok(true)
771 } else {
772 let err = std::io::Error::last_os_error();
773 if err.raw_os_error()
774 == Some(windows_sys::Win32::Foundation::ERROR_LOCK_VIOLATION as i32)
775 {
776 Ok(false)
777 } else {
778 Err(LockError::Io(err))
779 }
780 }
781 }
782
783 #[cfg(windows)]
784 fn try_exclusive_lock(file: &File) -> std::result::Result<bool, LockError> {
785 use std::os::windows::io::AsRawHandle;
786 let handle = file.as_raw_handle() as windows_sys::Win32::Foundation::HANDLE;
787 let mut overlapped: windows_sys::Win32::System::IO::OVERLAPPED =
788 unsafe { std::mem::zeroed() };
789
790 let result = unsafe {
791 windows_sys::Win32::Storage::FileSystem::LockFileEx(
792 handle,
793 windows_sys::Win32::Storage::FileSystem::LOCKFILE_EXCLUSIVE_LOCK
794 | windows_sys::Win32::Storage::FileSystem::LOCKFILE_FAIL_IMMEDIATELY,
795 0,
796 1,
797 0,
798 &mut overlapped,
799 )
800 };
801
802 if result != 0 {
803 Ok(true)
804 } else {
805 let err = std::io::Error::last_os_error();
806 if err.raw_os_error()
807 == Some(windows_sys::Win32::Foundation::ERROR_LOCK_VIOLATION as i32)
808 {
809 Ok(false)
810 } else {
811 Err(LockError::Io(err))
812 }
813 }
814 }
815
816 #[cfg(not(any(unix, windows)))]
817 fn try_shared_lock(_file: &File) -> std::result::Result<bool, LockError> {
818 Ok(true)
819 }
820
821 #[cfg(not(any(unix, windows)))]
822 fn try_exclusive_lock(_file: &File) -> std::result::Result<bool, LockError> {
823 Ok(true)
824 }
825
826 #[cfg(unix)]
827 fn release(&self) {
828 use std::os::unix::io::AsRawFd;
829 let fd = self.lock_file.as_raw_fd();
830 unsafe { libc::flock(fd, libc::LOCK_UN) };
831 }
832
833 #[cfg(windows)]
834 fn release(&self) {
835 use std::os::windows::io::AsRawHandle;
836 let handle = self.lock_file.as_raw_handle() as windows_sys::Win32::Foundation::HANDLE;
837 let mut overlapped: windows_sys::Win32::System::IO::OVERLAPPED =
838 unsafe { std::mem::zeroed() };
839 unsafe {
840 windows_sys::Win32::Storage::FileSystem::UnlockFileEx(handle, 0, 1, 0, &mut overlapped);
841 }
842 }
843
844 #[cfg(not(any(unix, windows)))]
845 fn release(&self) {}
846}
847
848impl Drop for RwDatabaseLock {
849 fn drop(&mut self) {
850 self.release();
851 }
852}
853
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use tempfile::TempDir;

    #[test]
    fn test_exclusive_lock_basic() {
        let dir = TempDir::new().unwrap();
        let db_path = dir.path();

        let lock1 = DatabaseLock::acquire(db_path);
        assert!(lock1.is_ok());

        // A second handle must fail fast while the first lock is alive.
        let lock2 = DatabaseLock::acquire_no_wait(db_path);
        assert!(matches!(lock2, Err(LockError::DatabaseLocked { .. })));

        drop(lock1);
        let lock3 = DatabaseLock::acquire(db_path);
        assert!(lock3.is_ok());
    }

    #[test]
    fn test_acquire_default_timeout() {
        let dir = TempDir::new().unwrap();
        let db_path = dir.path().to_path_buf();

        let lock_holder = DatabaseLock::acquire(&db_path).unwrap();

        // Release the lock from another thread after a short delay; the default
        // configuration keeps retrying until then.
        let handle = thread::spawn(move || {
            thread::sleep(Duration::from_millis(200));
            drop(lock_holder);
        });

        let start = Instant::now();
        let result = DatabaseLock::acquire(&db_path);
        let elapsed = start.elapsed();

        assert!(
            result.is_ok(),
            "acquire() should succeed after lock is released"
        );
        assert!(
            elapsed >= Duration::from_millis(100),
            "should have waited for lock"
        );
        assert!(elapsed < Duration::from_secs(2), "should not wait too long");

        handle.join().unwrap();
    }

    #[test]
    fn test_lock_with_timeout() {
        let dir = TempDir::new().unwrap();
        let db_path = dir.path().to_path_buf();

        let _lock = DatabaseLock::acquire(&db_path).unwrap();

        let start = Instant::now();
        let result = DatabaseLock::acquire_with_timeout(&db_path, Duration::from_millis(100));
        let elapsed = start.elapsed();

        match result {
            Err(LockError::Timeout {
                elapsed: reported,
                timeout,
            }) => {
                assert_eq!(timeout, Duration::from_millis(100));
                assert!(reported >= timeout, "reported wait shorter than the timeout");
            }
            other => panic!("expected LockError::Timeout, got {:?}", other.map(|_| ())),
        }
        assert!(elapsed >= Duration::from_millis(100));
        assert!(elapsed < Duration::from_millis(500));
    }

    #[test]
    fn test_lock_pid_recorded() {
        let dir = TempDir::new().unwrap();
        let db_path = dir.path();

        let lock = DatabaseLock::acquire(db_path).unwrap();
        let our_pid = std::process::id();

        assert_eq!(lock.pid(), our_pid);

        let holder = DatabaseLock::get_lock_holder(db_path);
        assert_eq!(holder, Some(our_pid));
    }

    #[test]
    fn test_lock_file_removed_on_drop() {
        let dir = TempDir::new().unwrap();
        let db_path = dir.path();

        let lock = DatabaseLock::acquire(db_path).unwrap();
        let lock_path = lock.path().to_path_buf();
        assert!(lock_path.exists(), "lock file should exist while held");

        drop(lock);
        assert!(!lock_path.exists(), "lock file should be removed on drop");
        assert_eq!(DatabaseLock::get_lock_holder(db_path), None);
    }

    #[test]
    fn test_shared_lock_multiple_readers() {
        let dir = TempDir::new().unwrap();
        let db_path = dir.path();

        let lock1 = RwDatabaseLock::acquire_shared(db_path);
        let lock2 = RwDatabaseLock::acquire_shared(db_path);

        assert!(lock1.is_ok());
        assert!(lock2.is_ok());
    }

    #[test]
    fn test_exclusive_blocks_shared() {
        let dir = TempDir::new().unwrap();
        let db_path = dir.path();

        let _exclusive = RwDatabaseLock::acquire_exclusive(db_path).unwrap();

        let shared = RwDatabaseLock::acquire_with_mode(
            db_path,
            ConnectionMode::ReadOnly,
            &LockConfig::no_wait(),
        );

        assert!(matches!(shared, Err(LockError::DatabaseLocked { .. })));
    }

    #[test]
    fn test_shared_blocks_exclusive() {
        let dir = TempDir::new().unwrap();
        let db_path = dir.path();

        let reader = RwDatabaseLock::acquire_shared(db_path).unwrap();
        assert!(reader.is_readonly());

        let writer = RwDatabaseLock::acquire_with_mode(
            db_path,
            ConnectionMode::ReadWrite,
            &LockConfig::no_wait(),
        );
        assert!(matches!(writer, Err(LockError::DatabaseLocked { .. })));

        // Once the reader is gone the writer gets in.
        drop(reader);
        let writer = RwDatabaseLock::acquire_exclusive(db_path).unwrap();
        assert_eq!(writer.mode(), ConnectionMode::ReadWrite);
        assert!(!writer.is_readonly());
    }

    #[test]
    fn test_rw_lock_timeout() {
        let dir = TempDir::new().unwrap();
        let db_path = dir.path();

        let _reader = RwDatabaseLock::acquire_shared(db_path).unwrap();

        let result = RwDatabaseLock::acquire_with_mode(
            db_path,
            ConnectionMode::ReadWrite,
            &LockConfig::with_timeout(Duration::from_millis(50)),
        );
        assert!(matches!(
            result,
            Err(LockError::Timeout { elapsed, timeout })
                if timeout == Duration::from_millis(50) && elapsed >= timeout
        ));
    }
}