tensorlogic_adapters/
locking.rs1use crate::{AdapterError, SymbolTable};
46use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard, TryLockError};
47use std::time::{Duration, Instant};
48
/// Counters describing how a `LockedSymbolTable` has been accessed.
///
/// A fresh value (via [`LockStats::new`] or [`Default`]) has every counter at
/// zero. Wait times are accumulated in whole milliseconds.
#[derive(Debug, Clone, Default)]
pub struct LockStats {
    /// Read locks successfully acquired.
    pub read_locks: usize,
    /// Write locks successfully acquired.
    pub write_locks: usize,
    /// Read attempts that gave up because the lock was busy.
    pub read_contentions: usize,
    /// Write attempts that gave up because the lock was busy.
    pub write_contentions: usize,
    /// Total milliseconds spent waiting for read locks.
    pub read_wait_ms: u128,
    /// Total milliseconds spent waiting for write locks.
    pub write_wait_ms: u128,
    /// Transactions begun.
    pub transactions_started: usize,
    /// Transactions committed.
    pub transactions_committed: usize,
    /// Transactions rolled back, explicitly or on drop.
    pub transactions_rolled_back: usize,
}

impl LockStats {
    /// Returns a zeroed set of counters; identical to [`LockStats::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Mean wait per acquired read lock in milliseconds, or `0.0` if no read
    /// lock has been acquired.
    pub fn avg_read_wait_ms(&self) -> f64 {
        Self::ratio(self.read_wait_ms as f64, self.read_locks)
    }

    /// Mean wait per acquired write lock in milliseconds, or `0.0` if no write
    /// lock has been acquired.
    pub fn avg_write_wait_ms(&self) -> f64 {
        Self::ratio(self.write_wait_ms as f64, self.write_locks)
    }

    /// Fraction of read attempts (successes + contentions) that hit
    /// contention, or `0.0` when there were no attempts.
    pub fn read_contention_rate(&self) -> f64 {
        let attempts = self.read_locks + self.read_contentions;
        Self::ratio(self.read_contentions as f64, attempts)
    }

    /// Fraction of write attempts (successes + contentions) that hit
    /// contention, or `0.0` when there were no attempts.
    pub fn write_contention_rate(&self) -> f64 {
        let attempts = self.write_locks + self.write_contentions;
        Self::ratio(self.write_contentions as f64, attempts)
    }

    /// Fraction of started transactions that were committed, or `0.0` when
    /// none were started.
    pub fn commit_rate(&self) -> f64 {
        Self::ratio(self.transactions_committed as f64, self.transactions_started)
    }

    /// Divides `numerator` by `count`, yielding `0.0` for an empty count
    /// instead of `NaN`/infinity.
    fn ratio(numerator: f64, count: usize) -> f64 {
        match count {
            0 => 0.0,
            n => numerator / n as f64,
        }
    }
}

126pub struct LockedSymbolTable {
131 table: RwLock<SymbolTable>,
132 stats: RwLock<LockStats>,
133}
134
135impl LockedSymbolTable {
136 pub fn new() -> Self {
138 Self {
139 table: RwLock::new(SymbolTable::new()),
140 stats: RwLock::new(LockStats::new()),
141 }
142 }
143
144 pub fn from_table(table: SymbolTable) -> Self {
146 Self {
147 table: RwLock::new(table),
148 stats: RwLock::new(LockStats::new()),
149 }
150 }
151
152 pub fn read(&self) -> RwLockReadGuard<'_, SymbolTable> {
157 let start = Instant::now();
158 let guard = self.table.read().unwrap();
159 let elapsed = start.elapsed().as_millis();
160
161 if let Ok(mut stats) = self.stats.write() {
162 stats.read_locks += 1;
163 stats.read_wait_ms += elapsed;
164 }
165
166 guard
167 }
168
169 pub fn try_read(&self) -> Option<RwLockReadGuard<'_, SymbolTable>> {
173 match self.table.try_read() {
174 Ok(guard) => {
175 if let Ok(mut stats) = self.stats.write() {
176 stats.read_locks += 1;
177 }
178 Some(guard)
179 }
180 Err(TryLockError::WouldBlock) => {
181 if let Ok(mut stats) = self.stats.write() {
182 stats.read_contentions += 1;
183 }
184 None
185 }
186 Err(TryLockError::Poisoned(_)) => None,
187 }
188 }
189
190 pub fn write(&self) -> RwLockWriteGuard<'_, SymbolTable> {
195 let start = Instant::now();
196 let guard = self.table.write().unwrap();
197 let elapsed = start.elapsed().as_millis();
198
199 if let Ok(mut stats) = self.stats.write() {
200 stats.write_locks += 1;
201 stats.write_wait_ms += elapsed;
202 }
203
204 guard
205 }
206
207 pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, SymbolTable>> {
211 match self.table.try_write() {
212 Ok(guard) => {
213 if let Ok(mut stats) = self.stats.write() {
214 stats.write_locks += 1;
215 }
216 Some(guard)
217 }
218 Err(TryLockError::WouldBlock) => {
219 if let Ok(mut stats) = self.stats.write() {
220 stats.write_contentions += 1;
221 }
222 None
223 }
224 Err(TryLockError::Poisoned(_)) => None,
225 }
226 }
227
228 pub fn stats(&self) -> LockStats {
230 self.stats.read().unwrap().clone()
231 }
232
233 pub fn reset_stats(&self) {
235 *self.stats.write().unwrap() = LockStats::new();
236 }
237
238 pub fn begin_transaction(&self) -> Transaction<'_> {
242 if let Ok(mut stats) = self.stats.write() {
243 stats.transactions_started += 1;
244 }
245 Transaction::new(self)
246 }
247}
248
249impl Default for LockedSymbolTable {
250 fn default() -> Self {
251 Self::new()
252 }
253}
254
255pub struct Transaction<'a> {
260 locked_table: &'a LockedSymbolTable,
261 snapshot: Option<SymbolTable>,
262 committed: bool,
263}
264
265impl<'a> Transaction<'a> {
266 fn new(locked_table: &'a LockedSymbolTable) -> Self {
267 let snapshot = locked_table.read().clone();
269 Self {
270 locked_table,
271 snapshot: Some(snapshot),
272 committed: false,
273 }
274 }
275
276 pub fn execute<F, R>(&mut self, f: F) -> Result<R, AdapterError>
280 where
281 F: FnOnce(&mut SymbolTable) -> Result<R, AdapterError>,
282 {
283 let mut guard = self.locked_table.write();
284 f(&mut guard)
285 }
286
287 pub fn commit(mut self) {
289 self.committed = true;
290 if let Ok(mut stats) = self.locked_table.stats.write() {
291 stats.transactions_committed += 1;
292 }
293 self.snapshot = None;
295 }
296
297 pub fn rollback(mut self) {
299 if let Some(snapshot) = self.snapshot.take() {
300 *self.locked_table.write() = snapshot;
301 }
302 if let Ok(mut stats) = self.locked_table.stats.write() {
303 stats.transactions_rolled_back += 1;
304 }
305 }
306}
307
308impl<'a> Drop for Transaction<'a> {
309 fn drop(&mut self) {
310 if !self.committed {
312 if let Some(snapshot) = self.snapshot.take() {
313 if let Ok(mut guard) = self.locked_table.table.write() {
314 *guard = snapshot;
315 }
316 if let Ok(mut stats) = self.locked_table.stats.write() {
317 stats.transactions_rolled_back += 1;
318 }
319 }
320 }
321 }
322}
323
324pub trait LockWithTimeout {
326 fn read_timeout(&self, timeout: Duration) -> Option<RwLockReadGuard<'_, SymbolTable>>;
330
331 fn write_timeout(&self, timeout: Duration) -> Option<RwLockWriteGuard<'_, SymbolTable>>;
335}
336
337impl LockWithTimeout for LockedSymbolTable {
338 fn read_timeout(&self, timeout: Duration) -> Option<RwLockReadGuard<'_, SymbolTable>> {
339 let start = Instant::now();
340 loop {
341 if let Some(guard) = self.try_read() {
342 return Some(guard);
343 }
344 if start.elapsed() >= timeout {
345 if let Ok(mut stats) = self.stats.write() {
346 stats.read_contentions += 1;
347 }
348 return None;
349 }
350 std::thread::sleep(Duration::from_millis(1));
351 }
352 }
353
354 fn write_timeout(&self, timeout: Duration) -> Option<RwLockWriteGuard<'_, SymbolTable>> {
355 let start = Instant::now();
356 loop {
357 if let Some(guard) = self.try_write() {
358 return Some(guard);
359 }
360 if start.elapsed() >= timeout {
361 if let Ok(mut stats) = self.stats.write() {
362 stats.write_contentions += 1;
363 }
364 return None;
365 }
366 std::thread::sleep(Duration::from_millis(1));
367 }
368 }
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DomainInfo;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn test_basic_read_write() {
        let table = LockedSymbolTable::new();

        {
            let mut guard = table.write();
            guard.add_domain(DomainInfo::new("User", 100)).unwrap();
        }

        {
            let guard = table.read();
            assert_eq!(guard.domains.len(), 1);
            assert!(guard.get_domain("User").is_some());
        }
    }

    #[test]
    fn test_multiple_readers() {
        let table = Arc::new(LockedSymbolTable::new());

        {
            let mut guard = table.write();
            guard.add_domain(DomainInfo::new("User", 100)).unwrap();
        }

        let mut handles = vec![];
        for _ in 0..5 {
            let table_clone = Arc::clone(&table);
            handles.push(thread::spawn(move || {
                let guard = table_clone.read();
                assert_eq!(guard.domains.len(), 1);
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn test_try_read_write() {
        let table = LockedSymbolTable::new();

        {
            let guard = table.try_read();
            assert!(guard.is_some());
        }

        {
            let guard = table.try_write();
            assert!(guard.is_some());
        }
    }

    #[test]
    fn test_try_write_contention() {
        let table = Arc::new(LockedSymbolTable::new());

        // Hold a read lock so the writer below cannot proceed.
        let _read_guard = table.read();

        let table_clone = Arc::clone(&table);
        let handle = thread::spawn(move || {
            let guard = table_clone.try_write();
            assert!(guard.is_none());
        });

        handle.join().unwrap();

        let stats = table.stats();
        assert!(stats.write_contentions > 0);
    }

    #[test]
    fn test_transaction_commit() {
        let table = LockedSymbolTable::new();

        {
            let mut txn = table.begin_transaction();
            txn.execute(|t| {
                t.add_domain(DomainInfo::new("User", 100)).unwrap();
                t.add_domain(DomainInfo::new("Post", 1000)).unwrap();
                Ok(())
            })
            .unwrap();
            txn.commit();
        }

        let guard = table.read();
        assert_eq!(guard.domains.len(), 2);

        let stats = table.stats();
        assert_eq!(stats.transactions_committed, 1);
    }

    #[test]
    fn test_transaction_rollback() {
        let table = LockedSymbolTable::new();

        {
            let mut guard = table.write();
            guard.add_domain(DomainInfo::new("User", 100)).unwrap();
        }

        {
            let mut txn = table.begin_transaction();
            txn.execute(|t| {
                t.add_domain(DomainInfo::new("Post", 1000)).unwrap();
                Ok(())
            })
            .unwrap();
            txn.rollback();
        }

        let guard = table.read();
        assert_eq!(guard.domains.len(), 1);
        assert!(guard.get_domain("Post").is_none());

        let stats = table.stats();
        assert_eq!(stats.transactions_rolled_back, 1);
    }

    #[test]
    fn test_transaction_auto_rollback() {
        let table = LockedSymbolTable::new();

        {
            let mut txn = table.begin_transaction();
            txn.execute(|t| {
                t.add_domain(DomainInfo::new("User", 100)).unwrap();
                Ok(())
            })
            .unwrap();
            // `txn` is dropped here without `commit`, which must roll back.
        }

        let guard = table.read();
        assert_eq!(guard.domains.len(), 0);

        let stats = table.stats();
        assert_eq!(stats.transactions_rolled_back, 1);
    }

    #[test]
    fn test_lock_stats() {
        let table = LockedSymbolTable::new();

        for _ in 0..3 {
            let _guard = table.read();
        }

        for _ in 0..2 {
            let _guard = table.write();
        }

        let stats = table.stats();
        assert_eq!(stats.read_locks, 3);
        assert_eq!(stats.write_locks, 2);
    }

    #[test]
    fn test_reset_stats() {
        let table = LockedSymbolTable::new();

        let _guard = table.read();
        assert_eq!(table.stats().read_locks, 1);

        table.reset_stats();
        assert_eq!(table.stats().read_locks, 0);
    }

    #[test]
    fn test_timeout_success() {
        let table = LockedSymbolTable::new();

        let guard = table.read_timeout(Duration::from_millis(100));
        assert!(guard.is_some());
    }

    #[test]
    fn test_timeout_failure() {
        let table = Arc::new(LockedSymbolTable::new());

        // Hold the write lock so the timed writer below must give up.
        let _write_guard = table.write();

        let table_clone = Arc::clone(&table);
        let handle = thread::spawn(move || {
            let guard = table_clone.write_timeout(Duration::from_millis(50));
            assert!(guard.is_none());
        });

        handle.join().unwrap();
    }

    #[test]
    fn test_concurrent_read_write() {
        let table = Arc::new(LockedSymbolTable::new());

        {
            let mut guard = table.write();
            guard.add_domain(DomainInfo::new("User", 100)).unwrap();
        }

        let mut handles = vec![];

        // Readers: repeatedly observe a non-empty table.
        for _ in 0..3 {
            let table_clone = Arc::clone(&table);
            handles.push(thread::spawn(move || {
                for _ in 0..10 {
                    let guard = table_clone.read();
                    assert!(!guard.domains.is_empty());
                    thread::sleep(Duration::from_millis(1));
                }
            }));
        }

        // Writers: each adds five uniquely named domains.
        for i in 0..2 {
            let table_clone = Arc::clone(&table);
            handles.push(thread::spawn(move || {
                for j in 0..5 {
                    let mut guard = table_clone.write();
                    let domain_name = format!("Domain_{}_{}", i, j);
                    guard
                        .add_domain(DomainInfo::new(&domain_name, 100))
                        .unwrap();
                    thread::sleep(Duration::from_millis(2));
                }
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        let guard = table.read();
        // One seed domain plus 2 writers x 5 unique domains: the count is exact.
        assert_eq!(guard.domains.len(), 11);

        let stats = table.stats();
        assert!(stats.read_locks > 0);
        assert!(stats.write_locks > 0);
    }

    #[test]
    fn test_stats_calculations() {
        let mut stats = LockStats::new();
        stats.read_locks = 10;
        stats.write_locks = 5;
        stats.read_wait_ms = 100;
        stats.write_wait_ms = 200;
        stats.read_contentions = 2;
        stats.write_contentions = 3;
        stats.transactions_started = 10;
        stats.transactions_committed = 8;

        assert_eq!(stats.avg_read_wait_ms(), 10.0);
        assert_eq!(stats.avg_write_wait_ms(), 40.0);
        assert!((stats.read_contention_rate() - 0.1667).abs() < 0.001);
        assert_eq!(stats.write_contention_rate(), 0.375);
        assert_eq!(stats.commit_rate(), 0.8);
    }

    #[test]
    fn test_transaction_error_handling() {
        let table = LockedSymbolTable::new();

        let result: Result<(), AdapterError> = {
            let mut txn = table.begin_transaction();
            txn.execute(|t| {
                t.add_domain(DomainInfo::new("User", 100)).unwrap();
                Err(AdapterError::DuplicateDomain("User".to_string()))
            })
        };

        assert!(result.is_err());

        // The uncommitted transaction was dropped, so its edit is undone.
        let guard = table.read();
        assert_eq!(guard.domains.len(), 0);
    }

    #[test]
    fn test_from_table() {
        let mut original = SymbolTable::new();
        original.add_domain(DomainInfo::new("User", 100)).unwrap();

        let locked = LockedSymbolTable::from_table(original);

        let guard = locked.read();
        assert_eq!(guard.domains.len(), 1);
        assert!(guard.get_domain("User").is_some());
    }
}