Skip to main content

tensorlogic_adapters/
locking.rs

1//! Multi-user schema management with locking.
2//!
3//! This module provides thread-safe concurrent access to symbol tables with
4//! read/write locking, transaction support, and lock statistics.
5//!
6//! # Features
7//!
8//! - **Read/Write Locks**: Multiple concurrent readers or single writer
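//! - **Transactions**: Snapshot-based commit/rollback; rolling back restores the whole table to its state when the transaction began (including discarding changes other threads made in the meantime)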
10//! - **Lock Statistics**: Monitor lock contention and usage patterns
11//! - **Timeout Support**: Prevent indefinite blocking on lock acquisition
//! - **Deadlock Mitigation**: No deadlock detection is performed; timed acquisition lets callers give up instead of blocking forever
13//!
14//! # Example
15//!
16//! ```rust
17//! use tensorlogic_adapters::{LockedSymbolTable, DomainInfo};
18//! use std::sync::Arc;
19//! use std::thread;
20//!
21//! let table = Arc::new(LockedSymbolTable::new());
22//!
23//! // Spawn multiple readers
24//! let mut handles = vec![];
25//! for i in 0..3 {
26//!     let table_clone = Arc::clone(&table);
27//!     handles.push(thread::spawn(move || {
28//!         let guard = table_clone.read();
29//!         println!("Reader {} sees {} domains", i, guard.domains.len());
30//!     }));
31//! }
32//!
33//! // Wait for readers
34//! for handle in handles {
35//!     handle.join().unwrap();
36//! }
37//!
38//! // Single writer
39//! {
40//!     let mut guard = table.write();
41//!     guard.add_domain(DomainInfo::new("User", 100)).unwrap();
42//! }
43//! ```
44
45use crate::{AdapterError, SymbolTable};
46use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard, TryLockError};
47use std::time::{Duration, Instant};
48
/// Counters describing how often, and how long, callers waited for a table lock.
///
/// All counters start at zero (see [`LockStats::new`]).
#[derive(Debug, Clone, Default)]
pub struct LockStats {
    /// Read locks that were successfully acquired.
    pub read_locks: usize,
    /// Write locks that were successfully acquired.
    pub write_locks: usize,
    /// Read attempts that failed because the lock was unavailable.
    pub read_contentions: usize,
    /// Write attempts that failed because the lock was unavailable.
    pub write_contentions: usize,
    /// Cumulative milliseconds spent blocked waiting for read locks.
    pub read_wait_ms: u128,
    /// Cumulative milliseconds spent blocked waiting for write locks.
    pub write_wait_ms: u128,
    /// Transactions that were begun.
    pub transactions_started: usize,
    /// Transactions that were committed.
    pub transactions_committed: usize,
    /// Transactions that were rolled back, explicitly or automatically on drop.
    pub transactions_rolled_back: usize,
}

impl LockStats {
    /// Returns statistics with every counter at zero.
    pub fn new() -> Self {
        LockStats::default()
    }

    /// `numerator / denominator`, or `0.0` when the denominator is zero.
    fn ratio(numerator: f64, denominator: usize) -> f64 {
        match denominator {
            0 => 0.0,
            count => numerator / count as f64,
        }
    }

    /// Mean milliseconds waited per successful read lock (`0.0` if none).
    pub fn avg_read_wait_ms(&self) -> f64 {
        Self::ratio(self.read_wait_ms as f64, self.read_locks)
    }

    /// Mean milliseconds waited per successful write lock (`0.0` if none).
    pub fn avg_write_wait_ms(&self) -> f64 {
        Self::ratio(self.write_wait_ms as f64, self.write_locks)
    }

    /// Fraction of read attempts that hit contention, in `[0.0, 1.0]`.
    pub fn read_contention_rate(&self) -> f64 {
        let attempts = self.read_locks + self.read_contentions;
        Self::ratio(self.read_contentions as f64, attempts)
    }

    /// Fraction of write attempts that hit contention, in `[0.0, 1.0]`.
    pub fn write_contention_rate(&self) -> f64 {
        let attempts = self.write_locks + self.write_contentions;
        Self::ratio(self.write_contentions as f64, attempts)
    }

    /// Committed transactions divided by started ones (`0.0` if none started).
    ///
    /// Transactions still in flight count in the denominator.
    pub fn commit_rate(&self) -> f64 {
        Self::ratio(
            self.transactions_committed as f64,
            self.transactions_started,
        )
    }
}
125
126/// A thread-safe symbol table with read/write locking.
127///
128/// This wrapper provides concurrent access to a symbol table with read/write
129/// locks, transaction support, and lock statistics tracking.
130pub struct LockedSymbolTable {
131    table: RwLock<SymbolTable>,
132    stats: RwLock<LockStats>,
133}
134
135impl LockedSymbolTable {
136    /// Create a new locked symbol table.
137    pub fn new() -> Self {
138        Self {
139            table: RwLock::new(SymbolTable::new()),
140            stats: RwLock::new(LockStats::new()),
141        }
142    }
143
144    /// Create a locked symbol table from an existing symbol table.
145    pub fn from_table(table: SymbolTable) -> Self {
146        Self {
147            table: RwLock::new(table),
148            stats: RwLock::new(LockStats::new()),
149        }
150    }
151
152    /// Acquire a read lock on the symbol table.
153    ///
154    /// This will block until a read lock can be acquired. Multiple readers
155    /// can hold locks simultaneously.
156    pub fn read(&self) -> RwLockReadGuard<'_, SymbolTable> {
157        let start = Instant::now();
158        let guard = self.table.read().unwrap();
159        let elapsed = start.elapsed().as_millis();
160
161        if let Ok(mut stats) = self.stats.write() {
162            stats.read_locks += 1;
163            stats.read_wait_ms += elapsed;
164        }
165
166        guard
167    }
168
169    /// Try to acquire a read lock without blocking.
170    ///
171    /// Returns `Some(guard)` if successful, `None` if would block.
172    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, SymbolTable>> {
173        match self.table.try_read() {
174            Ok(guard) => {
175                if let Ok(mut stats) = self.stats.write() {
176                    stats.read_locks += 1;
177                }
178                Some(guard)
179            }
180            Err(TryLockError::WouldBlock) => {
181                if let Ok(mut stats) = self.stats.write() {
182                    stats.read_contentions += 1;
183                }
184                None
185            }
186            Err(TryLockError::Poisoned(_)) => None,
187        }
188    }
189
190    /// Acquire a write lock on the symbol table.
191    ///
192    /// This will block until a write lock can be acquired. Only one writer
193    /// can hold a lock at a time, and no readers can be active.
194    pub fn write(&self) -> RwLockWriteGuard<'_, SymbolTable> {
195        let start = Instant::now();
196        let guard = self.table.write().unwrap();
197        let elapsed = start.elapsed().as_millis();
198
199        if let Ok(mut stats) = self.stats.write() {
200            stats.write_locks += 1;
201            stats.write_wait_ms += elapsed;
202        }
203
204        guard
205    }
206
207    /// Try to acquire a write lock without blocking.
208    ///
209    /// Returns `Some(guard)` if successful, `None` if would block.
210    pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, SymbolTable>> {
211        match self.table.try_write() {
212            Ok(guard) => {
213                if let Ok(mut stats) = self.stats.write() {
214                    stats.write_locks += 1;
215                }
216                Some(guard)
217            }
218            Err(TryLockError::WouldBlock) => {
219                if let Ok(mut stats) = self.stats.write() {
220                    stats.write_contentions += 1;
221                }
222                None
223            }
224            Err(TryLockError::Poisoned(_)) => None,
225        }
226    }
227
228    /// Get current lock statistics.
229    pub fn stats(&self) -> LockStats {
230        self.stats.read().unwrap().clone()
231    }
232
233    /// Reset lock statistics.
234    pub fn reset_stats(&self) {
235        *self.stats.write().unwrap() = LockStats::new();
236    }
237
238    /// Start a new transaction.
239    ///
240    /// Returns a transaction object that can be committed or rolled back.
241    pub fn begin_transaction(&self) -> Transaction<'_> {
242        if let Ok(mut stats) = self.stats.write() {
243            stats.transactions_started += 1;
244        }
245        Transaction::new(self)
246    }
247}
248
249impl Default for LockedSymbolTable {
250    fn default() -> Self {
251        Self::new()
252    }
253}
254
255/// A transaction for atomic operations on a symbol table.
256///
257/// Transactions capture the state of the symbol table at the start and
258/// allow rolling back to that state if needed.
259pub struct Transaction<'a> {
260    locked_table: &'a LockedSymbolTable,
261    snapshot: Option<SymbolTable>,
262    committed: bool,
263}
264
265impl<'a> Transaction<'a> {
266    fn new(locked_table: &'a LockedSymbolTable) -> Self {
267        // Take snapshot
268        let snapshot = locked_table.read().clone();
269        Self {
270            locked_table,
271            snapshot: Some(snapshot),
272            committed: false,
273        }
274    }
275
276    /// Execute operations within this transaction.
277    ///
278    /// The closure receives a mutable reference to the symbol table.
279    pub fn execute<F, R>(&mut self, f: F) -> Result<R, AdapterError>
280    where
281        F: FnOnce(&mut SymbolTable) -> Result<R, AdapterError>,
282    {
283        let mut guard = self.locked_table.write();
284        f(&mut guard)
285    }
286
287    /// Commit the transaction, making all changes permanent.
288    pub fn commit(mut self) {
289        self.committed = true;
290        if let Ok(mut stats) = self.locked_table.stats.write() {
291            stats.transactions_committed += 1;
292        }
293        // Drop snapshot
294        self.snapshot = None;
295    }
296
297    /// Rollback the transaction, reverting all changes.
298    pub fn rollback(mut self) {
299        if let Some(snapshot) = self.snapshot.take() {
300            *self.locked_table.write() = snapshot;
301        }
302        if let Ok(mut stats) = self.locked_table.stats.write() {
303            stats.transactions_rolled_back += 1;
304        }
305    }
306}
307
308impl<'a> Drop for Transaction<'a> {
309    fn drop(&mut self) {
310        // Auto-rollback if not committed
311        if !self.committed {
312            if let Some(snapshot) = self.snapshot.take() {
313                if let Ok(mut guard) = self.locked_table.table.write() {
314                    *guard = snapshot;
315                }
316                if let Ok(mut stats) = self.locked_table.stats.write() {
317                    stats.transactions_rolled_back += 1;
318                }
319            }
320        }
321    }
322}
323
324/// Extension trait for timeout-based lock acquisition.
325pub trait LockWithTimeout {
326    /// Try to acquire a read lock with a timeout.
327    ///
328    /// Returns `Some(guard)` if successful within timeout, `None` otherwise.
329    fn read_timeout(&self, timeout: Duration) -> Option<RwLockReadGuard<'_, SymbolTable>>;
330
331    /// Try to acquire a write lock with a timeout.
332    ///
333    /// Returns `Some(guard)` if successful within timeout, `None` otherwise.
334    fn write_timeout(&self, timeout: Duration) -> Option<RwLockWriteGuard<'_, SymbolTable>>;
335}
336
337impl LockWithTimeout for LockedSymbolTable {
338    fn read_timeout(&self, timeout: Duration) -> Option<RwLockReadGuard<'_, SymbolTable>> {
339        let start = Instant::now();
340        loop {
341            if let Some(guard) = self.try_read() {
342                return Some(guard);
343            }
344            if start.elapsed() >= timeout {
345                if let Ok(mut stats) = self.stats.write() {
346                    stats.read_contentions += 1;
347                }
348                return None;
349            }
350            std::thread::sleep(Duration::from_millis(1));
351        }
352    }
353
354    fn write_timeout(&self, timeout: Duration) -> Option<RwLockWriteGuard<'_, SymbolTable>> {
355        let start = Instant::now();
356        loop {
357            if let Some(guard) = self.try_write() {
358                return Some(guard);
359            }
360            if start.elapsed() >= timeout {
361                if let Ok(mut stats) = self.stats.write() {
362                    stats.write_contentions += 1;
363                }
364                return None;
365            }
366            std::thread::sleep(Duration::from_millis(1));
367        }
368    }
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DomainInfo;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn test_basic_read_write() {
        let table = LockedSymbolTable::new();

        // Write
        {
            let mut guard = table.write();
            guard.add_domain(DomainInfo::new("User", 100)).unwrap();
        }

        // Read
        {
            let guard = table.read();
            assert_eq!(guard.domains.len(), 1);
            assert!(guard.get_domain("User").is_some());
        }
    }

    #[test]
    fn test_multiple_readers() {
        let table = Arc::new(LockedSymbolTable::new());

        // Add some data
        {
            let mut guard = table.write();
            guard.add_domain(DomainInfo::new("User", 100)).unwrap();
        }

        // Spawn multiple readers
        let mut handles = vec![];
        for _ in 0..5 {
            let table_clone = Arc::clone(&table);
            handles.push(thread::spawn(move || {
                let guard = table_clone.read();
                assert_eq!(guard.domains.len(), 1);
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn test_try_read_write() {
        let table = LockedSymbolTable::new();

        // Try read (should succeed)
        {
            let guard = table.try_read();
            assert!(guard.is_some());
        }

        // Try write (should succeed)
        {
            let guard = table.try_write();
            assert!(guard.is_some());
        }
    }

    #[test]
    fn test_try_write_contention() {
        let table = Arc::new(LockedSymbolTable::new());

        // Hold read lock
        let _read_guard = table.read();

        // Try write (should fail due to active reader)
        let table_clone = Arc::clone(&table);
        let handle = thread::spawn(move || {
            let guard = table_clone.try_write();
            assert!(guard.is_none());
        });

        handle.join().unwrap();

        // Exactly the one failed `try_write` above must be recorded.
        let stats = table.stats();
        assert_eq!(stats.write_contentions, 1);
    }

    #[test]
    fn test_transaction_commit() {
        let table = LockedSymbolTable::new();

        {
            let mut txn = table.begin_transaction();
            txn.execute(|t| {
                t.add_domain(DomainInfo::new("User", 100)).unwrap();
                t.add_domain(DomainInfo::new("Post", 1000)).unwrap();
                Ok(())
            })
            .unwrap();
            txn.commit();
        }

        let guard = table.read();
        assert_eq!(guard.domains.len(), 2);

        let stats = table.stats();
        assert_eq!(stats.transactions_committed, 1);
    }

    #[test]
    fn test_transaction_rollback() {
        let table = LockedSymbolTable::new();

        // Add initial domain
        {
            let mut guard = table.write();
            guard.add_domain(DomainInfo::new("User", 100)).unwrap();
        }

        {
            let mut txn = table.begin_transaction();
            txn.execute(|t| {
                t.add_domain(DomainInfo::new("Post", 1000)).unwrap();
                Ok(())
            })
            .unwrap();
            txn.rollback();
        }

        let guard = table.read();
        assert_eq!(guard.domains.len(), 1);
        assert!(guard.get_domain("Post").is_none());

        let stats = table.stats();
        assert_eq!(stats.transactions_rolled_back, 1);
    }

    #[test]
    fn test_transaction_auto_rollback() {
        let table = LockedSymbolTable::new();

        {
            let mut txn = table.begin_transaction();
            txn.execute(|t| {
                t.add_domain(DomainInfo::new("User", 100)).unwrap();
                Ok(())
            })
            .unwrap();
            // Drop without commit (auto-rollback)
        }

        let guard = table.read();
        assert_eq!(guard.domains.len(), 0);

        let stats = table.stats();
        assert_eq!(stats.transactions_rolled_back, 1);
    }

    #[test]
    fn test_lock_stats() {
        let table = LockedSymbolTable::new();

        // Read operations
        for _ in 0..3 {
            let _guard = table.read();
        }

        // Write operations
        for _ in 0..2 {
            let _guard = table.write();
        }

        let stats = table.stats();
        assert_eq!(stats.read_locks, 3);
        assert_eq!(stats.write_locks, 2);
    }

    #[test]
    fn test_reset_stats() {
        let table = LockedSymbolTable::new();

        let _guard = table.read();
        assert_eq!(table.stats().read_locks, 1);

        table.reset_stats();
        assert_eq!(table.stats().read_locks, 0);
    }

    #[test]
    fn test_timeout_success() {
        let table = LockedSymbolTable::new();

        let guard = table.read_timeout(Duration::from_millis(100));
        assert!(guard.is_some());
    }

    #[test]
    fn test_timeout_failure() {
        let table = Arc::new(LockedSymbolTable::new());

        // Hold write lock
        let _write_guard = table.write();

        // Try to acquire write lock with timeout in another thread
        let table_clone = Arc::clone(&table);
        let handle = thread::spawn(move || {
            let guard = table_clone.write_timeout(Duration::from_millis(50));
            assert!(guard.is_none());
        });

        handle.join().unwrap();
    }

    #[test]
    fn test_concurrent_read_write() {
        let table = Arc::new(LockedSymbolTable::new());

        // Initialize with data
        {
            let mut guard = table.write();
            guard.add_domain(DomainInfo::new("User", 100)).unwrap();
        }

        let mut handles = vec![];

        // Readers
        for _ in 0..3 {
            let table_clone = Arc::clone(&table);
            handles.push(thread::spawn(move || {
                for _ in 0..10 {
                    let guard = table_clone.read();
                    assert!(!guard.domains.is_empty());
                    thread::sleep(Duration::from_millis(1));
                }
            }));
        }

        // Writers
        for i in 0..2 {
            let table_clone = Arc::clone(&table);
            handles.push(thread::spawn(move || {
                for j in 0..5 {
                    let mut guard = table_clone.write();
                    let domain_name = format!("Domain_{}_{}", i, j);
                    guard
                        .add_domain(DomainInfo::new(&domain_name, 100))
                        .unwrap();
                    thread::sleep(Duration::from_millis(2));
                }
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        // Verify final state: every name is unique and each insert was unwrapped,
        // so the count is exact (1 initial + 2 writers x 5 domains). An inexact
        // `>=` check could not detect a duplicated write.
        let guard = table.read();
        assert_eq!(guard.domains.len(), 11);

        // Check stats
        let stats = table.stats();
        assert!(stats.read_locks > 0);
        assert!(stats.write_locks > 0);
    }

    #[test]
    fn test_stats_calculations() {
        let mut stats = LockStats::new();
        stats.read_locks = 10;
        stats.write_locks = 5;
        stats.read_wait_ms = 100;
        stats.write_wait_ms = 200;
        stats.read_contentions = 2;
        stats.write_contentions = 3;
        stats.transactions_started = 10;
        stats.transactions_committed = 8;

        assert_eq!(stats.avg_read_wait_ms(), 10.0);
        assert_eq!(stats.avg_write_wait_ms(), 40.0);
        assert!((stats.read_contention_rate() - 0.1667).abs() < 0.001);
        assert_eq!(stats.write_contention_rate(), 0.375);
        assert_eq!(stats.commit_rate(), 0.8);
    }

    #[test]
    fn test_transaction_error_handling() {
        let table = LockedSymbolTable::new();

        let result: Result<(), AdapterError> = {
            let mut txn = table.begin_transaction();
            txn.execute(|t| {
                t.add_domain(DomainInfo::new("User", 100)).unwrap();
                // Simulate error
                Err(AdapterError::DuplicateDomain("User".to_string()))
            })
        };

        assert!(result.is_err());

        // Transaction should auto-rollback when dropped uncommitted
        let guard = table.read();
        assert_eq!(guard.domains.len(), 0);
        assert_eq!(table.stats().transactions_rolled_back, 1);
    }

    #[test]
    fn test_from_table() {
        let mut original = SymbolTable::new();
        original.add_domain(DomainInfo::new("User", 100)).unwrap();

        let locked = LockedSymbolTable::from_table(original);

        let guard = locked.read();
        assert_eq!(guard.domains.len(), 1);
        assert!(guard.get_domain("User").is_some());
    }
}