Skip to main content

tensorlogic_adapters/
locking.rs

//! Multi-user schema management with locking.
//!
//! This module provides thread-safe concurrent access to symbol tables with
//! read/write locking, transaction support, and lock statistics.
//!
//! # Features
//!
//! - **Read/Write Locks**: Multiple concurrent readers or single writer
//! - **Transactions**: Atomic operations with commit/rollback support
//! - **Lock Statistics**: Monitor lock contention and usage patterns
//! - **Timeout Support**: Prevent indefinite blocking on lock acquisition
//! - **Deadlock Detection**: Basic deadlock prevention through timeouts
//!
//! # Example
//!
//! ```rust
//! use tensorlogic_adapters::{LockedSymbolTable, DomainInfo};
//! use std::sync::Arc;
//! use std::thread;
//!
//! let table = Arc::new(LockedSymbolTable::new());
//!
//! // Spawn multiple readers
//! let mut handles = vec![];
//! for i in 0..3 {
//!     let table_clone = Arc::clone(&table);
//!     handles.push(thread::spawn(move || {
//!         let guard = table_clone.read();
//!         println!("Reader {} sees {} domains", i, guard.domains.len());
//!     }));
//! }
//!
//! // Wait for readers
//! for handle in handles {
//!     handle.join().expect("unwrap");
//! }
//!
//! // Single writer
//! {
//!     let mut guard = table.write();
//!     guard.add_domain(DomainInfo::new("User", 100)).expect("unwrap");
//! }
//! ```
44
45use crate::{AdapterError, SymbolTable};
46use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard, TryLockError};
47use std::time::{Duration, Instant};
48
49/// Statistics about lock usage and contention.
/// Counters describing lock usage, lock contention and transaction outcomes.
///
/// Every counter is a public field; a freshly created value has all of them
/// set to zero.
#[derive(Debug, Clone, Default)]
pub struct LockStats {
    /// How many read locks were acquired successfully.
    pub read_locks: usize,
    /// How many write locks were acquired successfully.
    pub write_locks: usize,
    /// How many read lock attempts failed because they would have blocked.
    pub read_contentions: usize,
    /// How many write lock attempts failed because they would have blocked.
    pub write_contentions: usize,
    /// Accumulated time spent waiting for read locks, in milliseconds.
    pub read_wait_ms: u128,
    /// Accumulated time spent waiting for write locks, in milliseconds.
    pub write_wait_ms: u128,
    /// How many transactions were started.
    pub transactions_started: usize,
    /// How many transactions were committed.
    pub transactions_committed: usize,
    /// How many transactions were rolled back.
    pub transactions_rolled_back: usize,
}

impl LockStats {
    /// Build a statistics record with every counter at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Mean wait per acquired read lock, in milliseconds.
    ///
    /// Returns `0.0` while no read lock has been counted.
    pub fn avg_read_wait_ms(&self) -> f64 {
        Self::mean_wait_ms(self.read_wait_ms, self.read_locks)
    }

    /// Mean wait per acquired write lock, in milliseconds.
    ///
    /// Returns `0.0` while no write lock has been counted.
    pub fn avg_write_wait_ms(&self) -> f64 {
        Self::mean_wait_ms(self.write_wait_ms, self.write_locks)
    }

    /// Fraction of read attempts that were contended (0.0 to 1.0).
    ///
    /// Returns `0.0` while no read attempt has been counted.
    pub fn read_contention_rate(&self) -> f64 {
        Self::contended_share(self.read_contentions, self.read_locks)
    }

    /// Fraction of write attempts that were contended (0.0 to 1.0).
    ///
    /// Returns `0.0` while no write attempt has been counted.
    pub fn write_contention_rate(&self) -> f64 {
        Self::contended_share(self.write_contentions, self.write_locks)
    }

    /// Fraction of started transactions that were committed (0.0 to 1.0).
    ///
    /// Returns `0.0` while no transaction has been started.
    pub fn commit_rate(&self) -> f64 {
        match self.transactions_started {
            0 => 0.0,
            started => self.transactions_committed as f64 / started as f64,
        }
    }

    /// Total wait divided by the number of acquisitions, or `0.0` for none.
    fn mean_wait_ms(total_wait_ms: u128, acquisitions: usize) -> f64 {
        match acquisitions {
            0 => 0.0,
            count => total_wait_ms as f64 / count as f64,
        }
    }

    /// Contended attempts divided by all attempts, or `0.0` for none.
    fn contended_share(contended: usize, acquired: usize) -> f64 {
        let attempts = acquired + contended;
        if attempts == 0 {
            return 0.0;
        }
        contended as f64 / attempts as f64
    }
}
125
126/// A thread-safe symbol table with read/write locking.
127///
128/// This wrapper provides concurrent access to a symbol table with read/write
129/// locks, transaction support, and lock statistics tracking.
130pub struct LockedSymbolTable {
131    table: RwLock<SymbolTable>,
132    stats: RwLock<LockStats>,
133}
134
135impl LockedSymbolTable {
136    /// Create a new locked symbol table.
137    pub fn new() -> Self {
138        Self {
139            table: RwLock::new(SymbolTable::new()),
140            stats: RwLock::new(LockStats::new()),
141        }
142    }
143
144    /// Create a locked symbol table from an existing symbol table.
145    pub fn from_table(table: SymbolTable) -> Self {
146        Self {
147            table: RwLock::new(table),
148            stats: RwLock::new(LockStats::new()),
149        }
150    }
151
152    /// Acquire a read lock on the symbol table.
153    ///
154    /// This will block until a read lock can be acquired. Multiple readers
155    /// can hold locks simultaneously.
156    pub fn read(&self) -> RwLockReadGuard<'_, SymbolTable> {
157        let start = Instant::now();
158        let guard = self.table.read().expect("lock should not be poisoned");
159        let elapsed = start.elapsed().as_millis();
160
161        if let Ok(mut stats) = self.stats.write() {
162            stats.read_locks += 1;
163            stats.read_wait_ms += elapsed;
164        }
165
166        guard
167    }
168
169    /// Try to acquire a read lock without blocking.
170    ///
171    /// Returns `Some(guard)` if successful, `None` if would block.
172    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, SymbolTable>> {
173        match self.table.try_read() {
174            Ok(guard) => {
175                if let Ok(mut stats) = self.stats.write() {
176                    stats.read_locks += 1;
177                }
178                Some(guard)
179            }
180            Err(TryLockError::WouldBlock) => {
181                if let Ok(mut stats) = self.stats.write() {
182                    stats.read_contentions += 1;
183                }
184                None
185            }
186            Err(TryLockError::Poisoned(_)) => None,
187        }
188    }
189
190    /// Acquire a write lock on the symbol table.
191    ///
192    /// This will block until a write lock can be acquired. Only one writer
193    /// can hold a lock at a time, and no readers can be active.
194    pub fn write(&self) -> RwLockWriteGuard<'_, SymbolTable> {
195        let start = Instant::now();
196        let guard = self.table.write().expect("lock should not be poisoned");
197        let elapsed = start.elapsed().as_millis();
198
199        if let Ok(mut stats) = self.stats.write() {
200            stats.write_locks += 1;
201            stats.write_wait_ms += elapsed;
202        }
203
204        guard
205    }
206
207    /// Try to acquire a write lock without blocking.
208    ///
209    /// Returns `Some(guard)` if successful, `None` if would block.
210    pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, SymbolTable>> {
211        match self.table.try_write() {
212            Ok(guard) => {
213                if let Ok(mut stats) = self.stats.write() {
214                    stats.write_locks += 1;
215                }
216                Some(guard)
217            }
218            Err(TryLockError::WouldBlock) => {
219                if let Ok(mut stats) = self.stats.write() {
220                    stats.write_contentions += 1;
221                }
222                None
223            }
224            Err(TryLockError::Poisoned(_)) => None,
225        }
226    }
227
228    /// Get current lock statistics.
229    pub fn stats(&self) -> LockStats {
230        self.stats
231            .read()
232            .expect("lock should not be poisoned")
233            .clone()
234    }
235
236    /// Reset lock statistics.
237    pub fn reset_stats(&self) {
238        *self.stats.write().expect("lock should not be poisoned") = LockStats::new();
239    }
240
241    /// Start a new transaction.
242    ///
243    /// Returns a transaction object that can be committed or rolled back.
244    pub fn begin_transaction(&self) -> Transaction<'_> {
245        if let Ok(mut stats) = self.stats.write() {
246            stats.transactions_started += 1;
247        }
248        Transaction::new(self)
249    }
250}
251
252impl Default for LockedSymbolTable {
253    fn default() -> Self {
254        Self::new()
255    }
256}
257
258/// A transaction for atomic operations on a symbol table.
259///
260/// Transactions capture the state of the symbol table at the start and
261/// allow rolling back to that state if needed.
262pub struct Transaction<'a> {
263    locked_table: &'a LockedSymbolTable,
264    snapshot: Option<SymbolTable>,
265    committed: bool,
266}
267
268impl<'a> Transaction<'a> {
269    fn new(locked_table: &'a LockedSymbolTable) -> Self {
270        // Take snapshot
271        let snapshot = locked_table.read().clone();
272        Self {
273            locked_table,
274            snapshot: Some(snapshot),
275            committed: false,
276        }
277    }
278
279    /// Execute operations within this transaction.
280    ///
281    /// The closure receives a mutable reference to the symbol table.
282    pub fn execute<F, R>(&mut self, f: F) -> Result<R, AdapterError>
283    where
284        F: FnOnce(&mut SymbolTable) -> Result<R, AdapterError>,
285    {
286        let mut guard = self.locked_table.write();
287        f(&mut guard)
288    }
289
290    /// Commit the transaction, making all changes permanent.
291    pub fn commit(mut self) {
292        self.committed = true;
293        if let Ok(mut stats) = self.locked_table.stats.write() {
294            stats.transactions_committed += 1;
295        }
296        // Drop snapshot
297        self.snapshot = None;
298    }
299
300    /// Rollback the transaction, reverting all changes.
301    pub fn rollback(mut self) {
302        if let Some(snapshot) = self.snapshot.take() {
303            *self.locked_table.write() = snapshot;
304        }
305        if let Ok(mut stats) = self.locked_table.stats.write() {
306            stats.transactions_rolled_back += 1;
307        }
308    }
309}
310
311impl<'a> Drop for Transaction<'a> {
312    fn drop(&mut self) {
313        // Auto-rollback if not committed
314        if !self.committed {
315            if let Some(snapshot) = self.snapshot.take() {
316                if let Ok(mut guard) = self.locked_table.table.write() {
317                    *guard = snapshot;
318                }
319                if let Ok(mut stats) = self.locked_table.stats.write() {
320                    stats.transactions_rolled_back += 1;
321                }
322            }
323        }
324    }
325}
326
327/// Extension trait for timeout-based lock acquisition.
328pub trait LockWithTimeout {
329    /// Try to acquire a read lock with a timeout.
330    ///
331    /// Returns `Some(guard)` if successful within timeout, `None` otherwise.
332    fn read_timeout(&self, timeout: Duration) -> Option<RwLockReadGuard<'_, SymbolTable>>;
333
334    /// Try to acquire a write lock with a timeout.
335    ///
336    /// Returns `Some(guard)` if successful within timeout, `None` otherwise.
337    fn write_timeout(&self, timeout: Duration) -> Option<RwLockWriteGuard<'_, SymbolTable>>;
338}
339
340impl LockWithTimeout for LockedSymbolTable {
341    fn read_timeout(&self, timeout: Duration) -> Option<RwLockReadGuard<'_, SymbolTable>> {
342        let start = Instant::now();
343        loop {
344            if let Some(guard) = self.try_read() {
345                return Some(guard);
346            }
347            if start.elapsed() >= timeout {
348                if let Ok(mut stats) = self.stats.write() {
349                    stats.read_contentions += 1;
350                }
351                return None;
352            }
353            std::thread::sleep(Duration::from_millis(1));
354        }
355    }
356
357    fn write_timeout(&self, timeout: Duration) -> Option<RwLockWriteGuard<'_, SymbolTable>> {
358        let start = Instant::now();
359        loop {
360            if let Some(guard) = self.try_write() {
361                return Some(guard);
362            }
363            if start.elapsed() >= timeout {
364                if let Ok(mut stats) = self.stats.write() {
365                    stats.write_contentions += 1;
366                }
367                return None;
368            }
369            std::thread::sleep(Duration::from_millis(1));
370        }
371    }
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DomainInfo;
    use std::sync::Arc;
    use std::thread;

    /// A domain added through a write guard is visible to a later reader.
    #[test]
    fn test_basic_read_write() {
        let locked = LockedSymbolTable::new();

        // The temporary write guard is released at the end of the statement.
        locked
            .write()
            .add_domain(DomainInfo::new("User", 100))
            .expect("unwrap");

        let view = locked.read();
        assert_eq!(view.domains.len(), 1);
        assert!(view.get_domain("User").is_some());
    }

    /// Several threads can read the same table.
    #[test]
    fn test_multiple_readers() {
        let shared = Arc::new(LockedSymbolTable::new());
        shared
            .write()
            .add_domain(DomainInfo::new("User", 100))
            .expect("unwrap");

        let readers: Vec<_> = (0..5)
            .map(|_| {
                let shared = Arc::clone(&shared);
                thread::spawn(move || {
                    let view = shared.read();
                    assert_eq!(view.domains.len(), 1);
                })
            })
            .collect();

        for reader in readers {
            reader.join().expect("unwrap");
        }
    }

    /// Non-blocking acquisition succeeds on an uncontended table.
    #[test]
    fn test_try_read_write() {
        let locked = LockedSymbolTable::new();

        assert!(locked.try_read().is_some());
        assert!(locked.try_write().is_some());
    }

    /// `try_write` fails, and is counted, while a reader holds the lock.
    #[test]
    fn test_try_write_contention() {
        let shared = Arc::new(LockedSymbolTable::new());

        // Kept alive until the end of the test so the writer cannot get in.
        let _held_reader = shared.read();

        let contender = {
            let shared = Arc::clone(&shared);
            thread::spawn(move || {
                assert!(shared.try_write().is_none());
            })
        };
        contender.join().expect("unwrap");

        assert!(shared.stats().write_contentions > 0);
    }

    /// Changes made inside a committed transaction stay in the table.
    #[test]
    fn test_transaction_commit() {
        let locked = LockedSymbolTable::new();

        let mut txn = locked.begin_transaction();
        txn.execute(|symbols| {
            symbols
                .add_domain(DomainInfo::new("User", 100))
                .expect("unwrap");
            symbols
                .add_domain(DomainInfo::new("Post", 1000))
                .expect("unwrap");
            Ok(())
        })
        .expect("unwrap");
        txn.commit();

        let view = locked.read();
        assert_eq!(view.domains.len(), 2);

        assert_eq!(locked.stats().transactions_committed, 1);
    }

    /// An explicit rollback removes what the transaction added.
    #[test]
    fn test_transaction_rollback() {
        let locked = LockedSymbolTable::new();

        // State that must survive the rollback.
        locked
            .write()
            .add_domain(DomainInfo::new("User", 100))
            .expect("unwrap");

        let mut txn = locked.begin_transaction();
        txn.execute(|symbols| {
            symbols
                .add_domain(DomainInfo::new("Post", 1000))
                .expect("unwrap");
            Ok(())
        })
        .expect("unwrap");
        txn.rollback();

        let view = locked.read();
        assert_eq!(view.domains.len(), 1);
        assert!(view.get_domain("Post").is_none());

        assert_eq!(locked.stats().transactions_rolled_back, 1);
    }

    /// Dropping a transaction without committing rolls it back.
    #[test]
    fn test_transaction_auto_rollback() {
        let locked = LockedSymbolTable::new();

        let mut txn = locked.begin_transaction();
        txn.execute(|symbols| {
            symbols
                .add_domain(DomainInfo::new("User", 100))
                .expect("unwrap");
            Ok(())
        })
        .expect("unwrap");
        // No commit: dropping triggers the automatic rollback.
        drop(txn);

        let view = locked.read();
        assert_eq!(view.domains.len(), 0);

        assert_eq!(locked.stats().transactions_rolled_back, 1);
    }

    /// Each blocking acquisition is counted once.
    #[test]
    fn test_lock_stats() {
        let locked = LockedSymbolTable::new();

        (0..3).for_each(|_| drop(locked.read()));
        (0..2).for_each(|_| drop(locked.write()));

        let snapshot = locked.stats();
        assert_eq!(snapshot.read_locks, 3);
        assert_eq!(snapshot.write_locks, 2);
    }

    /// `reset_stats` brings the counters back to zero.
    #[test]
    fn test_reset_stats() {
        let locked = LockedSymbolTable::new();

        let _held_reader = locked.read();
        assert_eq!(locked.stats().read_locks, 1);

        locked.reset_stats();
        assert_eq!(locked.stats().read_locks, 0);
    }

    /// A timed read on an uncontended table succeeds.
    #[test]
    fn test_timeout_success() {
        let locked = LockedSymbolTable::new();

        assert!(locked.read_timeout(Duration::from_millis(100)).is_some());
    }

    /// A timed write gives up while another writer holds the lock.
    #[test]
    fn test_timeout_failure() {
        let shared = Arc::new(LockedSymbolTable::new());

        // Kept alive until the end of the test so the lock stays taken.
        let _held_writer = shared.write();

        let contender = {
            let shared = Arc::clone(&shared);
            thread::spawn(move || {
                assert!(shared.write_timeout(Duration::from_millis(50)).is_none());
            })
        };
        contender.join().expect("unwrap");
    }

    /// Readers and writers running concurrently leave a consistent table.
    #[test]
    fn test_concurrent_read_write() {
        let shared = Arc::new(LockedSymbolTable::new());

        // Seed the table so readers never observe it empty.
        shared
            .write()
            .add_domain(DomainInfo::new("User", 100))
            .expect("unwrap");

        let mut workers: Vec<thread::JoinHandle<()>> = Vec::new();

        // Three readers, ten reads each.
        workers.extend((0..3).map(|_| {
            let shared = Arc::clone(&shared);
            thread::spawn(move || {
                for _ in 0..10 {
                    let view = shared.read();
                    assert!(!view.domains.is_empty());
                    thread::sleep(Duration::from_millis(1));
                }
            })
        }));

        // Two writers, five uniquely named domains each.
        workers.extend((0..2).map(|i| {
            let shared = Arc::clone(&shared);
            thread::spawn(move || {
                for j in 0..5 {
                    let mut table = shared.write();
                    let domain_name = format!("Domain_{}_{}", i, j);
                    table
                        .add_domain(DomainInfo::new(&domain_name, 100))
                        .expect("unwrap");
                    thread::sleep(Duration::from_millis(2));
                }
            })
        }));

        for worker in workers {
            worker.join().expect("unwrap");
        }

        // 1 seeded domain + 10 added by the writers.
        let view = shared.read();
        assert!(view.domains.len() >= 11);

        let snapshot = shared.stats();
        assert!(snapshot.read_locks > 0);
        assert!(snapshot.write_locks > 0);
    }

    /// The derived ratios follow directly from the raw counters.
    #[test]
    fn test_stats_calculations() {
        let stats = LockStats {
            read_locks: 10,
            write_locks: 5,
            read_wait_ms: 100,
            write_wait_ms: 200,
            read_contentions: 2,
            write_contentions: 3,
            transactions_started: 10,
            transactions_committed: 8,
            ..LockStats::new()
        };

        assert_eq!(stats.avg_read_wait_ms(), 10.0);
        assert_eq!(stats.avg_write_wait_ms(), 40.0);
        assert!((stats.read_contention_rate() - 0.1667).abs() < 0.001);
        assert_eq!(stats.write_contention_rate(), 0.375);
        assert_eq!(stats.commit_rate(), 0.8);
    }

    /// A failing closure surfaces its error and the drop rolls back.
    #[test]
    fn test_transaction_error_handling() {
        let locked = LockedSymbolTable::new();

        let mut txn = locked.begin_transaction();
        let outcome: Result<(), AdapterError> = txn.execute(|symbols| {
            symbols
                .add_domain(DomainInfo::new("User", 100))
                .expect("unwrap");
            // Simulated failure after a partial change.
            Err(AdapterError::DuplicateDomain("User".to_string()))
        });
        drop(txn);

        assert!(outcome.is_err());

        // The uncommitted transaction was rolled back when it was dropped.
        let view = locked.read();
        assert_eq!(view.domains.len(), 0);
    }

    /// `from_table` keeps the contents of the table it wraps.
    #[test]
    fn test_from_table() {
        let mut seeded = SymbolTable::new();
        seeded
            .add_domain(DomainInfo::new("User", 100))
            .expect("unwrap");

        let locked = LockedSymbolTable::from_table(seeded);

        let view = locked.read();
        assert_eq!(view.domains.len(), 1);
        assert!(view.get_domain("User").is_some());
    }
}