1use std::fs::{File, OpenOptions};
61use std::io::{Read, Write};
62use std::path::{Path, PathBuf};
63use std::time::{Duration, Instant};
64
65use sochdb_core::SochDBError;
66
/// Errors produced while acquiring or inspecting a database lock.
#[derive(Debug)]
pub enum LockError {
    /// The lock is held elsewhere and the caller chose not to wait for it.
    DatabaseLocked {
        /// PID of the current holder, when it is known (`None` otherwise).
        holder_pid: Option<u32>,
        /// Path of the contended lock file.
        lock_path: PathBuf,
    },
    /// The lock was still held when the configured timeout expired.
    Timeout {
        /// How long acquisition was attempted before giving up.
        elapsed: Duration,
        /// The timeout that was configured.
        timeout: Duration,
    },
    /// The lock file records a process that no longer appears to be running.
    StaleLock {
        /// PID recorded in the lock file.
        stale_pid: u32,
    },
    /// An underlying filesystem or OS call failed.
    Io(std::io::Error),
}
96
97impl std::fmt::Display for LockError {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 match self {
100 LockError::DatabaseLocked { holder_pid, lock_path } => {
101 if let Some(pid) = holder_pid {
102 write!(f, "Database is locked by process {} (lock file: {})",
103 pid, lock_path.display())
104 } else {
105 write!(f, "Database is locked (lock file: {})", lock_path.display())
106 }
107 }
108 LockError::Timeout { elapsed, timeout } => {
109 write!(f, "Lock acquisition timed out after {:?} (timeout: {:?})",
110 elapsed, timeout)
111 }
112 LockError::StaleLock { stale_pid } => {
113 write!(f, "Stale lock detected from crashed process {}", stale_pid)
114 }
115 LockError::Io(e) => write!(f, "Lock I/O error: {}", e),
116 }
117 }
118}
119
120impl std::error::Error for LockError {
121 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
122 match self {
123 LockError::Io(e) => Some(e),
124 _ => None,
125 }
126 }
127}
128
129impl From<std::io::Error> for LockError {
130 fn from(e: std::io::Error) -> Self {
131 LockError::Io(e)
132 }
133}
134
135impl From<LockError> for SochDBError {
136 fn from(e: LockError) -> Self {
137 match e {
138 LockError::DatabaseLocked { holder_pid, lock_path } => {
139 SochDBError::LockError(format!(
140 "Database locked by PID {:?} (lock: {})",
141 holder_pid, lock_path.display()
142 ))
143 }
144 LockError::Timeout { elapsed, timeout } => {
145 SochDBError::LockError(format!(
146 "Lock timeout after {:?} (max: {:?})", elapsed, timeout
147 ))
148 }
149 LockError::StaleLock { stale_pid } => {
150 SochDBError::LockError(format!(
151 "Stale lock from crashed process {}", stale_pid
152 ))
153 }
154 LockError::Io(e) => SochDBError::Io(e),
155 }
156 }
157}
158
/// Tuning knobs for lock acquisition.
#[derive(Debug, Clone)]
pub struct LockConfig {
    /// Maximum time to keep retrying; `None` means give up after the first failed attempt.
    pub timeout: Option<Duration>,
    /// Pause between retries while waiting for a contended lock.
    pub retry_interval: Duration,
    /// Whether to check if the PID recorded in a contended lock file still belongs to a
    /// running process.
    pub detect_stale_locks: bool,
    /// Name of the lock file created inside the database directory.
    pub lock_file_name: String,
}
175
176impl Default for LockConfig {
177 fn default() -> Self {
178 Self {
179 timeout: Some(Duration::from_secs(5)),
180 retry_interval: Duration::from_millis(100),
181 detect_stale_locks: true,
182 lock_file_name: ".lock".to_string(),
183 }
184 }
185}
186
187impl LockConfig {
188 pub fn no_wait() -> Self {
190 Self {
191 timeout: None,
192 ..Default::default()
193 }
194 }
195
196 pub fn with_timeout(timeout: Duration) -> Self {
198 Self {
199 timeout: Some(timeout),
200 ..Default::default()
201 }
202 }
203}
204
/// Exclusive, process-level lock on a database directory.
///
/// The OS lock held on `lock_file` is what keeps other processes out; it is released when
/// the value is dropped.
pub struct DatabaseLock {
    /// Open lock file on which the exclusive OS lock is held.
    lock_file: File,
    /// Full path of the lock file.
    path: PathBuf,
    /// PID of this process, as written into the lock file.
    our_pid: u32,
}
232
233impl DatabaseLock {
234 pub fn acquire<P: AsRef<Path>>(db_path: P) -> std::result::Result<Self, LockError> {
252 Self::acquire_with_config(db_path, &LockConfig::no_wait())
253 }
254
255 pub fn acquire_with_timeout<P: AsRef<Path>>(
264 db_path: P,
265 timeout: Duration
266 ) -> std::result::Result<Self, LockError> {
267 Self::acquire_with_config(db_path, &LockConfig::with_timeout(timeout))
268 }
269
270 pub fn acquire_with_config<P: AsRef<Path>>(
272 db_path: P,
273 config: &LockConfig,
274 ) -> std::result::Result<Self, LockError> {
275 let db_path = db_path.as_ref();
276 let lock_path = db_path.join(&config.lock_file_name);
277
278 if !db_path.exists() {
280 std::fs::create_dir_all(db_path)?;
281 }
282
283 let deadline = config.timeout.map(|t| Instant::now() + t);
284 let our_pid = std::process::id();
285
286 loop {
287 let file = OpenOptions::new()
289 .create(true)
290 .read(true)
291 .write(true)
292 .open(&lock_path)?;
293
294 match Self::try_flock(&file, false) {
296 Ok(()) => {
297 Self::write_pid(&file, our_pid)?;
299
300 return Ok(Self {
301 lock_file: file,
302 path: lock_path,
303 our_pid,
304 });
305 }
306 Err(LockError::DatabaseLocked { .. }) => {
307 let mut should_retry = false;
311 if config.detect_stale_locks {
312 if let Some(holder_pid) = Self::read_pid(&file) {
313 if !Self::process_exists(holder_pid) {
314 drop(file);
317
318 if std::fs::remove_file(&lock_path).is_ok() {
320 should_retry = true;
321 }
322 }
323 }
324 }
325
326 if should_retry {
327 continue; }
329
330 if let Some(deadline) = deadline {
332 if Instant::now() >= deadline {
333 return Err(LockError::Timeout {
334 elapsed: config.timeout.unwrap_or_default(),
335 timeout: config.timeout.unwrap_or_default(),
336 });
337 }
338
339 std::thread::sleep(config.retry_interval);
341 continue;
342 } else {
343 return Err(LockError::DatabaseLocked {
346 holder_pid: None,
347 lock_path
348 });
349 }
350 }
351 Err(e) => return Err(e),
352 }
353 }
354 }
355
356 pub fn path(&self) -> &Path {
358 &self.path
359 }
360
361 pub fn pid(&self) -> u32 {
363 self.our_pid
364 }
365
366 pub fn get_lock_holder<P: AsRef<Path>>(db_path: P) -> Option<u32> {
370 let lock_path = db_path.as_ref().join(".lock");
371 let file = File::open(&lock_path).ok()?;
372 Self::read_pid(&file)
373 }
374
375 fn write_pid(file: &File, pid: u32) -> std::result::Result<(), LockError> {
377 use std::io::Seek;
378 let mut file = file;
379 file.seek(std::io::SeekFrom::Start(0))?;
380 file.set_len(0)?;
381 writeln!(file, "{}", pid)?;
382 file.sync_all()?;
383 Ok(())
384 }
385
386 fn read_pid(file: &File) -> Option<u32> {
388 use std::io::Seek;
389 let mut file = file;
390 let _ = file.seek(std::io::SeekFrom::Start(0));
391 let mut contents = String::new();
392 file.read_to_string(&mut contents).ok()?;
393 contents.trim().parse().ok()
394 }
395
396 #[cfg(unix)]
398 fn process_exists(pid: u32) -> bool {
399 let result = unsafe { libc::kill(pid as libc::pid_t, 0) };
402 if result == 0 {
403 true
404 } else {
405 let errno = std::io::Error::last_os_error().raw_os_error();
407 errno != Some(libc::ESRCH)
408 }
409 }
410
411 #[cfg(windows)]
412 fn process_exists(pid: u32) -> bool {
413 unsafe {
414 let handle = windows_sys::Win32::System::Threading::OpenProcess(
415 windows_sys::Win32::System::Threading::PROCESS_QUERY_LIMITED_INFORMATION,
416 0,
417 pid,
418 );
419 if handle == 0 || handle == -1 {
420 false
421 } else {
422 windows_sys::Win32::Foundation::CloseHandle(handle);
423 true
424 }
425 }
426 }
427
428 #[cfg(not(any(unix, windows)))]
429 fn process_exists(_pid: u32) -> bool {
430 true
432 }
433
434 #[cfg(unix)]
436 fn try_flock(file: &File, blocking: bool) -> std::result::Result<(), LockError> {
437 use std::os::unix::io::AsRawFd;
438
439 let fd = file.as_raw_fd();
440 let operation = if blocking {
441 libc::LOCK_EX
442 } else {
443 libc::LOCK_EX | libc::LOCK_NB
444 };
445
446 let result = unsafe { libc::flock(fd, operation) };
447
448 if result == 0 {
449 Ok(())
450 } else {
451 let err = std::io::Error::last_os_error();
452 if err.raw_os_error() == Some(libc::EWOULDBLOCK) {
453 Err(LockError::DatabaseLocked {
454 holder_pid: None,
455 lock_path: PathBuf::new(),
456 })
457 } else {
458 Err(LockError::Io(err))
459 }
460 }
461 }
462
463 #[cfg(windows)]
464 fn try_flock(file: &File, blocking: bool) -> std::result::Result<(), LockError> {
465 use std::os::windows::io::AsRawHandle;
466
467 let handle = file.as_raw_handle() as windows_sys::Win32::Foundation::HANDLE;
468
469 let flags = windows_sys::Win32::Storage::FileSystem::LOCKFILE_EXCLUSIVE_LOCK
470 | if blocking { 0 } else { windows_sys::Win32::Storage::FileSystem::LOCKFILE_FAIL_IMMEDIATELY };
471
472 let mut overlapped: windows_sys::Win32::System::IO::OVERLAPPED = unsafe { std::mem::zeroed() };
473
474 let result = unsafe {
475 windows_sys::Win32::Storage::FileSystem::LockFileEx(
476 handle,
477 flags,
478 0,
479 1,
480 0,
481 &mut overlapped,
482 )
483 };
484
485 if result != 0 {
486 Ok(())
487 } else {
488 let err = std::io::Error::last_os_error();
489 if err.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_LOCK_VIOLATION as i32) {
490 Err(LockError::DatabaseLocked {
491 holder_pid: None,
492 lock_path: PathBuf::new(),
493 })
494 } else {
495 Err(LockError::Io(err))
496 }
497 }
498 }
499
500 #[cfg(not(any(unix, windows)))]
501 fn try_flock(_file: &File, _blocking: bool) -> std::result::Result<(), LockError> {
502 Ok(())
505 }
506
507 #[cfg(unix)]
509 fn release(&self) {
510 use std::os::unix::io::AsRawFd;
511 let fd = self.lock_file.as_raw_fd();
512 unsafe { libc::flock(fd, libc::LOCK_UN) };
513 }
514
515 #[cfg(windows)]
516 fn release(&self) {
517 use std::os::windows::io::AsRawHandle;
518 let handle = self.lock_file.as_raw_handle() as windows_sys::Win32::Foundation::HANDLE;
519 let mut overlapped: windows_sys::Win32::System::IO::OVERLAPPED = unsafe { std::mem::zeroed() };
520 unsafe {
521 windows_sys::Win32::Storage::FileSystem::UnlockFileEx(
522 handle,
523 0,
524 1,
525 0,
526 &mut overlapped,
527 );
528 }
529 }
530
531 #[cfg(not(any(unix, windows)))]
532 fn release(&self) {
533 }
535}
536
537impl Drop for DatabaseLock {
538 fn drop(&mut self) {
539 self.release();
540 let _ = std::fs::remove_file(&self.path);
543 }
544}
545
/// Fixed-layout (`repr(C)`) record of four `u32` counters, 16 bytes in total.
///
/// NOTE(review): nothing in the visible code reads or writes this type. It looks intended
/// for reader/writer bookkeeping shared between processes. Confirm against its users before
/// relying on the field meanings below.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct RwLockState {
    /// Presumably the number of active readers.
    pub reader_count: u32,
    /// Presumably non-zero when a writer is waiting.
    pub writer_intent: u32,
    /// Presumably non-zero while a writer holds the lock.
    pub writer_active: u32,
    /// Unused; pads the struct to 16 bytes.
    pub _padding: u32,
}
571
/// How a connection intends to use the database, which selects the lock kind to take.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionMode {
    /// Shared lock: may coexist with other read-only holders.
    ReadOnly,
    /// Exclusive lock: no other holder of any kind.
    ReadWrite,
}
580
581pub struct RwDatabaseLock {
588 lock_file: File,
590 path: PathBuf,
592 mode: ConnectionMode,
594 our_pid: u32,
596}
597
598impl RwDatabaseLock {
599 pub fn acquire_shared<P: AsRef<Path>>(db_path: P) -> std::result::Result<Self, LockError> {
604 Self::acquire_with_mode(db_path, ConnectionMode::ReadOnly, &LockConfig::default())
605 }
606
607 pub fn acquire_exclusive<P: AsRef<Path>>(db_path: P) -> std::result::Result<Self, LockError> {
612 Self::acquire_with_mode(db_path, ConnectionMode::ReadWrite, &LockConfig::default())
613 }
614
615 pub fn acquire_with_mode<P: AsRef<Path>>(
617 db_path: P,
618 mode: ConnectionMode,
619 config: &LockConfig,
620 ) -> std::result::Result<Self, LockError> {
621 let db_path = db_path.as_ref();
622 let lock_path = db_path.join(&config.lock_file_name);
623
624 if !db_path.exists() {
625 std::fs::create_dir_all(db_path)?;
626 }
627
628 let file = OpenOptions::new()
629 .create(true)
630 .read(true)
631 .write(true)
632 .open(&lock_path)?;
633
634 let our_pid = std::process::id();
635 let deadline = config.timeout.map(|t| Instant::now() + t);
636
637 loop {
638 match mode {
639 ConnectionMode::ReadOnly => {
640 if Self::try_shared_lock(&file)? {
642 return Ok(Self {
643 lock_file: file,
644 path: lock_path,
645 mode,
646 our_pid,
647 });
648 }
649 }
650 ConnectionMode::ReadWrite => {
651 if Self::try_exclusive_lock(&file)? {
653 return Ok(Self {
654 lock_file: file,
655 path: lock_path,
656 mode,
657 our_pid,
658 });
659 }
660 }
661 }
662
663 if let Some(deadline) = deadline {
665 if Instant::now() >= deadline {
666 return Err(LockError::Timeout {
667 elapsed: config.timeout.unwrap_or_default(),
668 timeout: config.timeout.unwrap_or_default(),
669 });
670 }
671 std::thread::sleep(config.retry_interval);
672 } else {
673 return Err(LockError::DatabaseLocked {
674 holder_pid: None,
675 lock_path,
676 });
677 }
678 }
679 }
680
681 pub fn mode(&self) -> ConnectionMode {
683 self.mode
684 }
685
686 pub fn is_readonly(&self) -> bool {
688 self.mode == ConnectionMode::ReadOnly
689 }
690
691 #[cfg(unix)]
692 fn try_shared_lock(file: &File) -> std::result::Result<bool, LockError> {
693 use std::os::unix::io::AsRawFd;
694 let fd = file.as_raw_fd();
695 let result = unsafe { libc::flock(fd, libc::LOCK_SH | libc::LOCK_NB) };
696 if result == 0 {
697 Ok(true)
698 } else {
699 let err = std::io::Error::last_os_error();
700 if err.raw_os_error() == Some(libc::EWOULDBLOCK) {
701 Ok(false)
702 } else {
703 Err(LockError::Io(err))
704 }
705 }
706 }
707
708 #[cfg(unix)]
709 fn try_exclusive_lock(file: &File) -> std::result::Result<bool, LockError> {
710 use std::os::unix::io::AsRawFd;
711 let fd = file.as_raw_fd();
712 let result = unsafe { libc::flock(fd, libc::LOCK_EX | libc::LOCK_NB) };
713 if result == 0 {
714 Ok(true)
715 } else {
716 let err = std::io::Error::last_os_error();
717 if err.raw_os_error() == Some(libc::EWOULDBLOCK) {
718 Ok(false)
719 } else {
720 Err(LockError::Io(err))
721 }
722 }
723 }
724
725 #[cfg(windows)]
726 fn try_shared_lock(file: &File) -> std::result::Result<bool, LockError> {
727 use std::os::windows::io::AsRawHandle;
728 let handle = file.as_raw_handle() as windows_sys::Win32::Foundation::HANDLE;
729 let mut overlapped: windows_sys::Win32::System::IO::OVERLAPPED = unsafe { std::mem::zeroed() };
730
731 let result = unsafe {
732 windows_sys::Win32::Storage::FileSystem::LockFileEx(
733 handle,
734 windows_sys::Win32::Storage::FileSystem::LOCKFILE_FAIL_IMMEDIATELY,
735 0, 1, 0,
736 &mut overlapped,
737 )
738 };
739
740 if result != 0 {
741 Ok(true)
742 } else {
743 let err = std::io::Error::last_os_error();
744 if err.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_LOCK_VIOLATION as i32) {
745 Ok(false)
746 } else {
747 Err(LockError::Io(err))
748 }
749 }
750 }
751
752 #[cfg(windows)]
753 fn try_exclusive_lock(file: &File) -> std::result::Result<bool, LockError> {
754 use std::os::windows::io::AsRawHandle;
755 let handle = file.as_raw_handle() as windows_sys::Win32::Foundation::HANDLE;
756 let mut overlapped: windows_sys::Win32::System::IO::OVERLAPPED = unsafe { std::mem::zeroed() };
757
758 let result = unsafe {
759 windows_sys::Win32::Storage::FileSystem::LockFileEx(
760 handle,
761 windows_sys::Win32::Storage::FileSystem::LOCKFILE_EXCLUSIVE_LOCK
762 | windows_sys::Win32::Storage::FileSystem::LOCKFILE_FAIL_IMMEDIATELY,
763 0, 1, 0,
764 &mut overlapped,
765 )
766 };
767
768 if result != 0 {
769 Ok(true)
770 } else {
771 let err = std::io::Error::last_os_error();
772 if err.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_LOCK_VIOLATION as i32) {
773 Ok(false)
774 } else {
775 Err(LockError::Io(err))
776 }
777 }
778 }
779
780 #[cfg(not(any(unix, windows)))]
781 fn try_shared_lock(_file: &File) -> std::result::Result<bool, LockError> {
782 Ok(true)
783 }
784
785 #[cfg(not(any(unix, windows)))]
786 fn try_exclusive_lock(_file: &File) -> std::result::Result<bool, LockError> {
787 Ok(true)
788 }
789
790 #[cfg(unix)]
791 fn release(&self) {
792 use std::os::unix::io::AsRawFd;
793 let fd = self.lock_file.as_raw_fd();
794 unsafe { libc::flock(fd, libc::LOCK_UN) };
795 }
796
797 #[cfg(windows)]
798 fn release(&self) {
799 use std::os::windows::io::AsRawHandle;
800 let handle = self.lock_file.as_raw_handle() as windows_sys::Win32::Foundation::HANDLE;
801 let mut overlapped: windows_sys::Win32::System::IO::OVERLAPPED = unsafe { std::mem::zeroed() };
802 unsafe {
803 windows_sys::Win32::Storage::FileSystem::UnlockFileEx(handle, 0, 1, 0, &mut overlapped);
804 }
805 }
806
807 #[cfg(not(any(unix, windows)))]
808 fn release(&self) {}
809}
810
811impl Drop for RwDatabaseLock {
812 fn drop(&mut self) {
813 self.release();
814 }
815}
816
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_exclusive_lock_basic() {
        let dir = TempDir::new().unwrap();
        let db_path = dir.path();

        let lock1 = DatabaseLock::acquire(db_path);
        assert!(lock1.is_ok());

        // A second exclusive attempt (even from the same process) must be refused.
        let lock2 = DatabaseLock::acquire(db_path);
        assert!(matches!(lock2, Err(LockError::DatabaseLocked { .. })));

        drop(lock1);
        let lock3 = DatabaseLock::acquire(db_path);
        assert!(lock3.is_ok());
    }

    #[test]
    fn test_lock_with_timeout() {
        let dir = TempDir::new().unwrap();
        let db_path = dir.path().to_path_buf();

        let _lock = DatabaseLock::acquire(&db_path).unwrap();

        let start = Instant::now();
        let result = DatabaseLock::acquire_with_timeout(&db_path, Duration::from_millis(100));
        let elapsed = start.elapsed();

        assert!(matches!(result, Err(LockError::Timeout { .. })));
        assert!(elapsed >= Duration::from_millis(100));
        // Generous upper bound to tolerate slow CI machines.
        assert!(elapsed < Duration::from_millis(500));
    }

    #[test]
    fn test_timeout_reports_configured_timeout() {
        let dir = TempDir::new().unwrap();
        let db_path = dir.path();
        let _lock = DatabaseLock::acquire(db_path).unwrap();

        let timeout = Duration::from_millis(20);
        match DatabaseLock::acquire_with_timeout(db_path, timeout) {
            Err(LockError::Timeout {
                timeout: reported, ..
            }) => assert_eq!(reported, timeout),
            Err(other) => panic!("expected a timeout, got: {}", other),
            Ok(_) => panic!("lock unexpectedly acquired"),
        }
    }

    #[test]
    fn test_lock_pid_recorded() {
        let dir = TempDir::new().unwrap();
        let db_path = dir.path();

        let lock = DatabaseLock::acquire(db_path).unwrap();
        let our_pid = std::process::id();

        assert_eq!(lock.pid(), our_pid);

        let holder = DatabaseLock::get_lock_holder(db_path);
        assert_eq!(holder, Some(our_pid));
    }

    #[test]
    fn test_lock_file_removed_on_drop() {
        let dir = TempDir::new().unwrap();
        let lock = DatabaseLock::acquire(dir.path()).unwrap();
        let lock_path = lock.path().to_path_buf();
        assert!(lock_path.exists());

        drop(lock);
        assert!(!lock_path.exists());
        assert_eq!(DatabaseLock::get_lock_holder(dir.path()), None);
    }

    #[test]
    fn test_shared_lock_multiple_readers() {
        let dir = TempDir::new().unwrap();
        let db_path = dir.path();

        let lock1 = RwDatabaseLock::acquire_shared(db_path);
        let lock2 = RwDatabaseLock::acquire_shared(db_path);

        assert!(lock1.is_ok());
        assert!(lock2.is_ok());
    }

    #[test]
    fn test_shared_lock_mode() {
        let dir = TempDir::new().unwrap();
        let lock = RwDatabaseLock::acquire_shared(dir.path()).unwrap();
        assert_eq!(lock.mode(), ConnectionMode::ReadOnly);
        assert!(lock.is_readonly());
    }

    #[test]
    fn test_exclusive_blocks_shared() {
        let dir = TempDir::new().unwrap();
        let db_path = dir.path();

        let _exclusive = RwDatabaseLock::acquire_exclusive(db_path).unwrap();

        let shared = RwDatabaseLock::acquire_with_mode(
            db_path,
            ConnectionMode::ReadOnly,
            &LockConfig::no_wait(),
        );

        assert!(matches!(shared, Err(LockError::DatabaseLocked { .. })));
    }

    #[test]
    fn test_shared_blocks_exclusive() {
        let dir = TempDir::new().unwrap();
        let db_path = dir.path();

        let _shared = RwDatabaseLock::acquire_shared(db_path).unwrap();

        let exclusive = RwDatabaseLock::acquire_with_mode(
            db_path,
            ConnectionMode::ReadWrite,
            &LockConfig::no_wait(),
        );

        assert!(matches!(exclusive, Err(LockError::DatabaseLocked { .. })));
    }
}