tensorlogic_adapters/
locking.rs1use crate::{AdapterError, SymbolTable};
46use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard, TryLockError};
47use std::time::{Duration, Instant};
48
/// Usage counters and accumulated wait times for a lock-protected table.
///
/// The fields are raw totals. The methods on this type derive averages and
/// ratios from them and return `0.0` whenever the denominator would be zero.
#[derive(Debug, Clone, Default)]
pub struct LockStats {
    /// Number of read locks acquired.
    pub read_locks: usize,
    /// Number of write locks acquired.
    pub write_locks: usize,
    /// Number of read-lock attempts recorded as contended.
    pub read_contentions: usize,
    /// Number of write-lock attempts recorded as contended.
    pub write_contentions: usize,
    /// Total time, in milliseconds, spent waiting for read locks.
    pub read_wait_ms: u128,
    /// Total time, in milliseconds, spent waiting for write locks.
    pub write_wait_ms: u128,
    /// Number of transactions started.
    pub transactions_started: usize,
    /// Number of transactions committed.
    pub transactions_committed: usize,
    /// Number of transactions rolled back.
    pub transactions_rolled_back: usize,
}

impl LockStats {
    /// Creates a statistics record with every counter set to zero.
    pub fn new() -> Self {
        Default::default()
    }

    /// Mean wait per acquired read lock, in milliseconds (`0.0` if none were acquired).
    pub fn avg_read_wait_ms(&self) -> f64 {
        match self.read_locks {
            0 => 0.0,
            acquired => self.read_wait_ms as f64 / acquired as f64,
        }
    }

    /// Mean wait per acquired write lock, in milliseconds (`0.0` if none were acquired).
    pub fn avg_write_wait_ms(&self) -> f64 {
        match self.write_locks {
            0 => 0.0,
            acquired => self.write_wait_ms as f64 / acquired as f64,
        }
    }

    /// Fraction of read attempts (acquisitions plus contentions) that were contended.
    pub fn read_contention_rate(&self) -> f64 {
        let attempts = self.read_locks + self.read_contentions;
        if attempts != 0 {
            self.read_contentions as f64 / attempts as f64
        } else {
            0.0
        }
    }

    /// Fraction of write attempts (acquisitions plus contentions) that were contended.
    pub fn write_contention_rate(&self) -> f64 {
        let attempts = self.write_locks + self.write_contentions;
        if attempts != 0 {
            self.write_contentions as f64 / attempts as f64
        } else {
            0.0
        }
    }

    /// Fraction of started transactions that were committed (`0.0` if none were started).
    pub fn commit_rate(&self) -> f64 {
        match self.transactions_started {
            0 => 0.0,
            started => self.transactions_committed as f64 / started as f64,
        }
    }
}
125
126pub struct LockedSymbolTable {
131 table: RwLock<SymbolTable>,
132 stats: RwLock<LockStats>,
133}
134
135impl LockedSymbolTable {
136 pub fn new() -> Self {
138 Self {
139 table: RwLock::new(SymbolTable::new()),
140 stats: RwLock::new(LockStats::new()),
141 }
142 }
143
144 pub fn from_table(table: SymbolTable) -> Self {
146 Self {
147 table: RwLock::new(table),
148 stats: RwLock::new(LockStats::new()),
149 }
150 }
151
152 pub fn read(&self) -> RwLockReadGuard<'_, SymbolTable> {
157 let start = Instant::now();
158 let guard = self.table.read().expect("lock should not be poisoned");
159 let elapsed = start.elapsed().as_millis();
160
161 if let Ok(mut stats) = self.stats.write() {
162 stats.read_locks += 1;
163 stats.read_wait_ms += elapsed;
164 }
165
166 guard
167 }
168
169 pub fn try_read(&self) -> Option<RwLockReadGuard<'_, SymbolTable>> {
173 match self.table.try_read() {
174 Ok(guard) => {
175 if let Ok(mut stats) = self.stats.write() {
176 stats.read_locks += 1;
177 }
178 Some(guard)
179 }
180 Err(TryLockError::WouldBlock) => {
181 if let Ok(mut stats) = self.stats.write() {
182 stats.read_contentions += 1;
183 }
184 None
185 }
186 Err(TryLockError::Poisoned(_)) => None,
187 }
188 }
189
190 pub fn write(&self) -> RwLockWriteGuard<'_, SymbolTable> {
195 let start = Instant::now();
196 let guard = self.table.write().expect("lock should not be poisoned");
197 let elapsed = start.elapsed().as_millis();
198
199 if let Ok(mut stats) = self.stats.write() {
200 stats.write_locks += 1;
201 stats.write_wait_ms += elapsed;
202 }
203
204 guard
205 }
206
207 pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, SymbolTable>> {
211 match self.table.try_write() {
212 Ok(guard) => {
213 if let Ok(mut stats) = self.stats.write() {
214 stats.write_locks += 1;
215 }
216 Some(guard)
217 }
218 Err(TryLockError::WouldBlock) => {
219 if let Ok(mut stats) = self.stats.write() {
220 stats.write_contentions += 1;
221 }
222 None
223 }
224 Err(TryLockError::Poisoned(_)) => None,
225 }
226 }
227
228 pub fn stats(&self) -> LockStats {
230 self.stats
231 .read()
232 .expect("lock should not be poisoned")
233 .clone()
234 }
235
236 pub fn reset_stats(&self) {
238 *self.stats.write().expect("lock should not be poisoned") = LockStats::new();
239 }
240
241 pub fn begin_transaction(&self) -> Transaction<'_> {
245 if let Ok(mut stats) = self.stats.write() {
246 stats.transactions_started += 1;
247 }
248 Transaction::new(self)
249 }
250}
251
252impl Default for LockedSymbolTable {
253 fn default() -> Self {
254 Self::new()
255 }
256}
257
258pub struct Transaction<'a> {
263 locked_table: &'a LockedSymbolTable,
264 snapshot: Option<SymbolTable>,
265 committed: bool,
266}
267
268impl<'a> Transaction<'a> {
269 fn new(locked_table: &'a LockedSymbolTable) -> Self {
270 let snapshot = locked_table.read().clone();
272 Self {
273 locked_table,
274 snapshot: Some(snapshot),
275 committed: false,
276 }
277 }
278
279 pub fn execute<F, R>(&mut self, f: F) -> Result<R, AdapterError>
283 where
284 F: FnOnce(&mut SymbolTable) -> Result<R, AdapterError>,
285 {
286 let mut guard = self.locked_table.write();
287 f(&mut guard)
288 }
289
290 pub fn commit(mut self) {
292 self.committed = true;
293 if let Ok(mut stats) = self.locked_table.stats.write() {
294 stats.transactions_committed += 1;
295 }
296 self.snapshot = None;
298 }
299
300 pub fn rollback(mut self) {
302 if let Some(snapshot) = self.snapshot.take() {
303 *self.locked_table.write() = snapshot;
304 }
305 if let Ok(mut stats) = self.locked_table.stats.write() {
306 stats.transactions_rolled_back += 1;
307 }
308 }
309}
310
311impl<'a> Drop for Transaction<'a> {
312 fn drop(&mut self) {
313 if !self.committed {
315 if let Some(snapshot) = self.snapshot.take() {
316 if let Ok(mut guard) = self.locked_table.table.write() {
317 *guard = snapshot;
318 }
319 if let Ok(mut stats) = self.locked_table.stats.write() {
320 stats.transactions_rolled_back += 1;
321 }
322 }
323 }
324 }
325}
326
327pub trait LockWithTimeout {
329 fn read_timeout(&self, timeout: Duration) -> Option<RwLockReadGuard<'_, SymbolTable>>;
333
334 fn write_timeout(&self, timeout: Duration) -> Option<RwLockWriteGuard<'_, SymbolTable>>;
338}
339
340impl LockWithTimeout for LockedSymbolTable {
341 fn read_timeout(&self, timeout: Duration) -> Option<RwLockReadGuard<'_, SymbolTable>> {
342 let start = Instant::now();
343 loop {
344 if let Some(guard) = self.try_read() {
345 return Some(guard);
346 }
347 if start.elapsed() >= timeout {
348 if let Ok(mut stats) = self.stats.write() {
349 stats.read_contentions += 1;
350 }
351 return None;
352 }
353 std::thread::sleep(Duration::from_millis(1));
354 }
355 }
356
357 fn write_timeout(&self, timeout: Duration) -> Option<RwLockWriteGuard<'_, SymbolTable>> {
358 let start = Instant::now();
359 loop {
360 if let Some(guard) = self.try_write() {
361 return Some(guard);
362 }
363 if start.elapsed() >= timeout {
364 if let Ok(mut stats) = self.stats.write() {
365 stats.write_contentions += 1;
366 }
367 return None;
368 }
369 std::thread::sleep(Duration::from_millis(1));
370 }
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DomainInfo;
    use std::sync::Arc;
    use std::thread;

    /// Builds a locked table that already contains the single domain `User`,
    /// inserted through one write lock.
    fn table_with_user() -> LockedSymbolTable {
        let table = LockedSymbolTable::new();
        table
            .write()
            .add_domain(DomainInfo::new("User", 100))
            .expect("unwrap");
        table
    }

    /// A value written under the write lock is visible under the read lock.
    #[test]
    fn test_basic_read_write() {
        let table = table_with_user();

        let view = table.read();
        assert_eq!(view.domains.len(), 1);
        assert!(view.get_domain("User").is_some());
    }

    /// Several threads can read the same table.
    #[test]
    fn test_multiple_readers() {
        let table = Arc::new(table_with_user());

        let readers: Vec<_> = (0..5)
            .map(|_| {
                let shared = Arc::clone(&table);
                thread::spawn(move || {
                    let view = shared.read();
                    assert_eq!(view.domains.len(), 1);
                })
            })
            .collect();

        for reader in readers {
            reader.join().expect("unwrap");
        }
    }

    /// Non-blocking acquisition succeeds on an idle table.
    #[test]
    fn test_try_read_write() {
        let table = LockedSymbolTable::new();

        assert!(table.try_read().is_some());
        assert!(table.try_write().is_some());
    }

    /// `try_write` fails and is counted while another thread holds a read guard.
    #[test]
    fn test_try_write_contention() {
        let table = Arc::new(LockedSymbolTable::new());

        // Kept alive for the whole test so the writer cannot get in.
        let _read_guard = table.read();

        let shared = Arc::clone(&table);
        let contender = thread::spawn(move || {
            assert!(shared.try_write().is_none());
        });
        contender.join().expect("unwrap");

        assert!(table.stats().write_contentions > 0);
    }

    /// Committed changes stay in the table and the commit is counted.
    #[test]
    fn test_transaction_commit() {
        let table = LockedSymbolTable::new();

        let mut txn = table.begin_transaction();
        txn.execute(|symbols| {
            symbols
                .add_domain(DomainInfo::new("User", 100))
                .expect("unwrap");
            symbols
                .add_domain(DomainInfo::new("Post", 1000))
                .expect("unwrap");
            Ok(())
        })
        .expect("unwrap");
        txn.commit();

        let view = table.read();
        assert_eq!(view.domains.len(), 2);

        assert_eq!(table.stats().transactions_committed, 1);
    }

    /// An explicit rollback restores the table as it was when the transaction began.
    #[test]
    fn test_transaction_rollback() {
        let table = table_with_user();

        let mut txn = table.begin_transaction();
        txn.execute(|symbols| {
            symbols
                .add_domain(DomainInfo::new("Post", 1000))
                .expect("unwrap");
            Ok(())
        })
        .expect("unwrap");
        txn.rollback();

        let view = table.read();
        assert_eq!(view.domains.len(), 1);
        assert!(view.get_domain("Post").is_none());

        assert_eq!(table.stats().transactions_rolled_back, 1);
    }

    /// Dropping a transaction without committing rolls it back.
    #[test]
    fn test_transaction_auto_rollback() {
        let table = LockedSymbolTable::new();

        let mut txn = table.begin_transaction();
        txn.execute(|symbols| {
            symbols
                .add_domain(DomainInfo::new("User", 100))
                .expect("unwrap");
            Ok(())
        })
        .expect("unwrap");
        drop(txn);

        let view = table.read();
        assert_eq!(view.domains.len(), 0);

        assert_eq!(table.stats().transactions_rolled_back, 1);
    }

    /// Each blocking acquisition is counted exactly once.
    #[test]
    fn test_lock_stats() {
        let table = LockedSymbolTable::new();

        (0..3).for_each(|_| drop(table.read()));
        (0..2).for_each(|_| drop(table.write()));

        let stats = table.stats();
        assert_eq!(stats.read_locks, 3);
        assert_eq!(stats.write_locks, 2);
    }

    /// `reset_stats` zeroes previously recorded counters.
    #[test]
    fn test_reset_stats() {
        let table = LockedSymbolTable::new();

        let _guard = table.read();
        assert_eq!(table.stats().read_locks, 1);

        table.reset_stats();
        assert_eq!(table.stats().read_locks, 0);
    }

    /// A timed read on an idle table succeeds.
    #[test]
    fn test_timeout_success() {
        let table = LockedSymbolTable::new();

        assert!(table.read_timeout(Duration::from_millis(100)).is_some());
    }

    /// A timed write gives up while another thread holds the write guard.
    #[test]
    fn test_timeout_failure() {
        let table = Arc::new(LockedSymbolTable::new());

        // Kept alive for the whole test so the timed attempt must fail.
        let _write_guard = table.write();

        let shared = Arc::clone(&table);
        let waiter = thread::spawn(move || {
            assert!(shared.write_timeout(Duration::from_millis(50)).is_none());
        });
        waiter.join().expect("unwrap");
    }

    /// Readers and writers interleave without losing any inserted domain.
    #[test]
    fn test_concurrent_read_write() {
        let table = Arc::new(table_with_user());

        let mut workers = Vec::new();

        // Three reader threads, ten reads each.
        for _ in 0..3 {
            let shared = Arc::clone(&table);
            workers.push(thread::spawn(move || {
                for _ in 0..10 {
                    let view = shared.read();
                    assert!(!view.domains.is_empty());
                    thread::sleep(Duration::from_millis(1));
                }
            }));
        }

        // Two writer threads, five uniquely named domains each.
        for writer in 0..2 {
            let shared = Arc::clone(&table);
            workers.push(thread::spawn(move || {
                for round in 0..5 {
                    let mut symbols = shared.write();
                    let domain_name = format!("Domain_{}_{}", writer, round);
                    symbols
                        .add_domain(DomainInfo::new(&domain_name, 100))
                        .expect("unwrap");
                    thread::sleep(Duration::from_millis(2));
                }
            }));
        }

        for worker in workers {
            worker.join().expect("unwrap");
        }

        // One seeded domain plus ten added by the writers.
        let view = table.read();
        assert!(view.domains.len() >= 11);

        let stats = table.stats();
        assert!(stats.read_locks > 0);
        assert!(stats.write_locks > 0);
    }

    /// Derived metrics follow directly from the raw counters.
    #[test]
    fn test_stats_calculations() {
        let stats = LockStats {
            read_locks: 10,
            write_locks: 5,
            read_wait_ms: 100,
            write_wait_ms: 200,
            read_contentions: 2,
            write_contentions: 3,
            transactions_started: 10,
            transactions_committed: 8,
            ..LockStats::new()
        };

        assert_eq!(stats.avg_read_wait_ms(), 10.0);
        assert_eq!(stats.avg_write_wait_ms(), 40.0);
        assert!((stats.read_contention_rate() - 0.1667).abs() < 0.001);
        assert_eq!(stats.write_contention_rate(), 0.375);
        assert_eq!(stats.commit_rate(), 0.8);
    }

    /// A failing `execute` followed by dropping the transaction leaves the table unchanged.
    #[test]
    fn test_transaction_error_handling() {
        let table = LockedSymbolTable::new();

        let outcome: Result<(), AdapterError> = {
            let mut txn = table.begin_transaction();
            txn.execute(|symbols| {
                symbols
                    .add_domain(DomainInfo::new("User", 100))
                    .expect("unwrap");
                Err(AdapterError::DuplicateDomain("User".to_string()))
            })
        };

        assert!(outcome.is_err());

        // The transaction was dropped uncommitted, so the insert was undone.
        let view = table.read();
        assert_eq!(view.domains.len(), 0);
    }

    /// `from_table` keeps the contents of the table it wraps.
    #[test]
    fn test_from_table() {
        let mut original = SymbolTable::new();
        original
            .add_domain(DomainInfo::new("User", 100))
            .expect("unwrap");

        let locked = LockedSymbolTable::from_table(original);

        let view = locked.read();
        assert_eq!(view.domains.len(), 1);
        assert!(view.get_domain("User").is_some());
    }
}