1use super::var_id::VarId;
9use crate::symbol::SymbolId;
10use std::collections::HashMap;
11
/// Kind of lock primitive an acquisition went through.
///
/// The `RwLock*` and `RefCell*` variants encode the access mode in the variant
/// itself; `ParkingLotRwLock` and `TokioRwLock` do not distinguish read from
/// write acquisitions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LockType {
    /// Mutex acquisition (presumably `std::sync::Mutex`).
    Mutex,
    /// Read acquisition of an `RwLock`.
    RwLockRead,
    /// Write acquisition of an `RwLock`.
    RwLockWrite,
    /// Shared `RefCell` borrow; reported as read-only.
    RefCell,
    /// Mutable `RefCell` borrow.
    RefCellMut,
    /// `parking_lot` mutex.
    ParkingLotMutex,
    /// `parking_lot` RwLock; access mode not recorded.
    ParkingLotRwLock,
    /// Tokio mutex.
    TokioMutex,
    /// Tokio RwLock; access mode not recorded.
    TokioRwLock,
}

impl LockType {
    /// Returns `true` only for [`LockType::RwLockRead`] and [`LockType::RefCell`].
    ///
    /// Every other variant, including the RwLock flavours that do not record
    /// an access mode, is reported as not read-only.
    pub fn is_read_only(&self) -> bool {
        match self {
            LockType::RwLockRead | LockType::RefCell => true,
            LockType::Mutex
            | LockType::RwLockWrite
            | LockType::RefCellMut
            | LockType::ParkingLotMutex
            | LockType::ParkingLotRwLock
            | LockType::TokioMutex
            | LockType::TokioRwLock => false,
        }
    }

    /// Logical complement of [`LockType::is_read_only`].
    pub fn is_exclusive(&self) -> bool {
        let read_only = self.is_read_only();
        !read_only
    }
}
50
/// How a field is touched inside a critical section.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AccessKind {
    /// Read access.
    Read,
    /// Write access.
    Write,
    /// Both read and written (also the result of merging differing kinds).
    ReadWrite,
}

impl AccessKind {
    /// Combines two access kinds observed for the same field.
    ///
    /// Two identical kinds combine to themselves; any mixture of different
    /// kinds yields [`AccessKind::ReadWrite`].
    pub fn merge(self, other: Self) -> Self {
        if self == other {
            self
        } else {
            AccessKind::ReadWrite
        }
    }
}
72
/// A recommendation produced by lock analysis.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LockSuggestion {
    /// Replace a lock-protected field with another type (e.g. an atomic).
    UseAtomic {
        /// Field the suggestion is about.
        field: String,
        /// Current type of the field, if known.
        current_type: Option<String>,
        /// Type to use instead.
        suggested_type: String,
        /// Line the suggestion refers to.
        line: u32,
    },

    /// Split one lock into several.
    SplitLock {
        /// Lock to split.
        lock_name: String,
        /// Proposed splits; the first element of each pair is reported as a
        /// field name (the second is presumably the proposed lock — confirm).
        suggested_splits: Vec<(String, String)>,
        /// Line the suggestion refers to.
        line: u32,
    },

    /// Shrink the region in which a guard is held.
    ReduceScope {
        /// Guard whose scope should shrink.
        guard_name: String,
        /// Current span (presumably `(start_line, end_line)` — confirm).
        current_span: (u32, u32),
        /// Suggested span, in the same form as `current_span`.
        suggested_span: (u32, u32),
        /// Human-readable justification.
        reason: String,
    },

    /// Switch a lock to an RwLock.
    UseRwLock {
        /// Lock in question.
        lock_name: String,
        /// Number of read accesses observed.
        read_count: usize,
        /// Number of write accesses observed.
        write_count: usize,
        /// Line the suggestion refers to.
        line: u32,
    },

    /// A guard is held across an `.await`.
    LockAcrossAwait {
        /// Guard held across the await.
        guard_name: String,
        /// Line where the lock was taken.
        lock_line: u32,
        /// Line of the await.
        await_line: u32,
    },

    /// A lock that may not be needed at all.
    RemoveLock {
        /// Lock in question.
        lock_name: String,
        /// Human-readable justification.
        reason: String,
        /// Line the suggestion refers to.
        line: u32,
    },
}

impl LockSuggestion {
    /// Fixed priority per variant, from 3 (`RemoveLock`) to 10
    /// (`LockAcrossAwait`); presumably higher means more urgent.
    pub fn severity(&self) -> u8 {
        match self {
            Self::LockAcrossAwait { .. } => 10,
            Self::UseAtomic { .. } => 8,
            Self::SplitLock { .. } => 6,
            Self::ReduceScope { .. } => 5,
            Self::UseRwLock { .. } => 4,
            Self::RemoveLock { .. } => 3,
        }
    }

    /// Full sentence describing the suggestion.
    pub fn description(&self) -> String {
        match self {
            Self::UseAtomic {
                field,
                suggested_type,
                ..
            } => format!("Consider using {} for field '{}'", suggested_type, field),
            Self::SplitLock {
                lock_name,
                suggested_splits,
                ..
            } => {
                // Only the field half of each pair is reported.
                let field_names: Vec<&str> = suggested_splits
                    .iter()
                    .map(|split| split.0.as_str())
                    .collect();
                format!(
                    "Consider splitting lock '{}' for fields: {:?}",
                    lock_name, field_names
                )
            }
            Self::ReduceScope {
                guard_name, reason, ..
            } => format!("Reduce scope of guard '{}': {}", guard_name, reason),
            Self::UseRwLock {
                lock_name,
                read_count,
                write_count,
                ..
            } => format!(
                "Consider RwLock for '{}' ({} reads, {} writes)",
                lock_name, read_count, write_count
            ),
            Self::LockAcrossAwait { guard_name, .. } => {
                format!("Lock '{}' is held across await point", guard_name)
            }
            Self::RemoveLock {
                lock_name, reason, ..
            } => format!("Lock '{}' may be unnecessary: {}", lock_name, reason),
        }
    }

    /// Compact, imperative description of the suggestion.
    pub fn short_description(&self) -> String {
        match self {
            Self::UseAtomic {
                field,
                suggested_type,
                ..
            } => format!("Use {} for field '{}'", suggested_type, field),
            Self::SplitLock {
                lock_name,
                suggested_splits,
                ..
            } => format!(
                "Split '{}' into {} separate locks",
                lock_name,
                suggested_splits.len()
            ),
            Self::ReduceScope {
                guard_name, reason, ..
            } => format!("Reduce scope of '{}': {}", guard_name, reason),
            Self::UseRwLock {
                lock_name,
                read_count,
                write_count,
                ..
            } => format!(
                "Use RwLock for '{}' ({} reads, {} writes)",
                lock_name, read_count, write_count
            ),
            Self::LockAcrossAwait {
                guard_name,
                lock_line,
                await_line,
            } => format!(
                "Lock '{}' held across await (lines {}-{})",
                guard_name, lock_line, await_line
            ),
            Self::RemoveLock {
                lock_name, reason, ..
            } => format!("Remove unnecessary '{}': {}", lock_name, reason),
        }
    }
}
261
262#[derive(Debug, Clone, PartialEq, Eq)]
268pub struct LockAcquisitionV2 {
269 pub lock_var: VarId,
271 pub guard_var: VarId,
273 pub lock_type: LockType,
275 pub line: u32,
277 pub is_try: bool,
279 pub lock_name: String,
281 pub guard_name: String,
283 pub owner_fn: Option<SymbolId>,
285}
286
287impl LockAcquisitionV2 {
288 pub fn new(
290 lock_var: VarId,
291 guard_var: VarId,
292 lock_type: LockType,
293 line: u32,
294 lock_name: impl Into<String>,
295 guard_name: impl Into<String>,
296 ) -> Self {
297 Self {
298 lock_var,
299 guard_var,
300 lock_type,
301 line,
302 is_try: false,
303 lock_name: lock_name.into(),
304 guard_name: guard_name.into(),
305 owner_fn: None,
306 }
307 }
308
309 pub fn with_owner_fn(mut self, owner: SymbolId) -> Self {
311 self.owner_fn = Some(owner);
312 self
313 }
314
315 pub fn with_try(mut self) -> Self {
317 self.is_try = true;
318 self
319 }
320}
321
322#[derive(Debug, Clone, PartialEq, Eq)]
328pub struct FieldAccessV2 {
329 pub field_name: String,
331 pub access_kind: AccessKind,
333 pub line: u32,
335}
336
337impl FieldAccessV2 {
338 pub fn new(field_name: impl Into<String>, access_kind: AccessKind, line: u32) -> Self {
340 Self {
341 field_name: field_name.into(),
342 access_kind,
343 line,
344 }
345 }
346}
347
348#[derive(Debug, Clone)]
354pub struct CriticalSectionV2 {
355 pub acquisition: LockAcquisitionV2,
357 pub start_line: u32,
359 pub end_line: Option<u32>,
361 pub field_accesses: Vec<FieldAccessV2>,
363 pub contains_expensive_ops: bool,
365 pub contains_await: bool,
367}
368
369impl CriticalSectionV2 {
370 pub fn new(acquisition: LockAcquisitionV2) -> Self {
372 let start_line = acquisition.line;
373 Self {
374 acquisition,
375 start_line,
376 end_line: None,
377 field_accesses: Vec::new(),
378 contains_expensive_ops: false,
379 contains_await: false,
380 }
381 }
382
383 pub fn end_at(&mut self, line: u32) {
385 self.end_line = Some(line);
386 }
387
388 pub fn add_field_access(&mut self, access: FieldAccessV2) {
390 self.field_accesses.push(access);
391 }
392
393 pub fn mark_expensive(&mut self) {
395 self.contains_expensive_ops = true;
396 }
397
398 pub fn mark_await(&mut self) {
400 self.contains_await = true;
401 }
402
403 pub fn unique_fields(&self) -> Vec<&str> {
405 let mut fields: Vec<&str> = self
406 .field_accesses
407 .iter()
408 .map(|a| a.field_name.as_str())
409 .collect();
410 fields.sort();
411 fields.dedup();
412 fields
413 }
414
415 pub fn field_access_kind(&self, field: &str) -> Option<AccessKind> {
417 self.field_accesses
418 .iter()
419 .filter(|a| a.field_name == field)
420 .map(|a| a.access_kind)
421 .reduce(|a, b| a.merge(b))
422 }
423
424 pub fn is_read_only(&self) -> bool {
426 self.field_accesses
427 .iter()
428 .all(|a| a.access_kind == AccessKind::Read)
429 }
430
431 pub fn span(&self) -> Option<u32> {
433 self.end_line.map(|end| end.saturating_sub(self.start_line))
434 }
435}
436
437#[derive(Debug, Clone, Default)]
443pub struct LockTrackerV2 {
444 acquisitions: Vec<LockAcquisitionV2>,
446 active_sections: HashMap<VarId, CriticalSectionV2>,
448 completed_sections: Vec<CriticalSectionV2>,
450}
451
452impl LockTrackerV2 {
453 pub fn new() -> Self {
455 Self::default()
456 }
457
458 pub fn acquire(&mut self, acquisition: LockAcquisitionV2) {
460 let guard_var = acquisition.guard_var;
461 self.acquisitions.push(acquisition.clone());
462 self.active_sections
463 .insert(guard_var, CriticalSectionV2::new(acquisition));
464 }
465
466 pub fn record_field_access(
468 &mut self,
469 guard_var: VarId,
470 field_name: &str,
471 access_kind: AccessKind,
472 line: u32,
473 ) {
474 if let Some(cs) = self.active_sections.get_mut(&guard_var) {
475 cs.add_field_access(FieldAccessV2::new(field_name, access_kind, line));
476 }
477 }
478
479 pub fn mark_expensive(&mut self, guard_var: VarId) {
481 if let Some(cs) = self.active_sections.get_mut(&guard_var) {
482 cs.mark_expensive();
483 }
484 }
485
486 pub fn mark_await(&mut self, guard_var: VarId, _await_line: u32) {
488 if let Some(cs) = self.active_sections.get_mut(&guard_var) {
489 cs.mark_await();
490 }
491 }
492
493 pub fn release(&mut self, guard_var: VarId, line: u32) {
495 if let Some(mut cs) = self.active_sections.remove(&guard_var) {
496 cs.end_at(line);
497 self.completed_sections.push(cs);
498 }
499 }
500
501 pub fn critical_sections(&self) -> &[CriticalSectionV2] {
503 &self.completed_sections
504 }
505
506 pub fn acquisitions(&self) -> &[LockAcquisitionV2] {
508 &self.acquisitions
509 }
510
511 pub fn acquisitions_by_owner(&self, owner: SymbolId) -> Vec<&LockAcquisitionV2> {
513 self.acquisitions
514 .iter()
515 .filter(|a| a.owner_fn == Some(owner))
516 .collect()
517 }
518
519 pub fn active_sections(&self) -> impl Iterator<Item = &CriticalSectionV2> {
521 self.active_sections.values()
522 }
523
524 pub fn is_active(&self, guard_var: VarId) -> bool {
526 self.active_sections.contains_key(&guard_var)
527 }
528
529 pub fn clear(&mut self) {
531 self.acquisitions.clear();
532 self.active_sections.clear();
533 self.completed_sections.clear();
534 }
535
536 pub fn completed_count(&self) -> usize {
538 self.completed_sections.len()
539 }
540
541 pub fn active_count(&self) -> usize {
543 self.active_sections.len()
544 }
545
546 pub fn flush_active_sections(&mut self) {
551 let active = std::mem::take(&mut self.active_sections);
552 for (_, cs) in active {
553 self.completed_sections.push(cs);
554 }
555 }
556}
557
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol::SymbolId;
    use slotmap::SlotMap;

    /// Hands out distinct `VarId`s, each backed by a freshly inserted symbol.
    struct TestVars {
        symbols: SlotMap<SymbolId, &'static str>,
        mapping: super::super::var_id::VarSymbolMapping,
    }

    impl TestVars {
        fn new() -> Self {
            TestVars {
                symbols: SlotMap::with_key(),
                mapping: super::super::var_id::VarSymbolMapping::new(),
            }
        }

        fn var(&mut self, name: &'static str) -> VarId {
            let symbol = self.symbols.insert(name);
            self.mapping.register(symbol)
        }

        /// Allocates a lock variable, then a guard variable (in that order).
        fn pair(&mut self, lock_name: &'static str, guard_name: &'static str) -> (VarId, VarId) {
            let lock = self.var(lock_name);
            let guard = self.var(guard_name);
            (lock, guard)
        }
    }

    /// A `LockType::Mutex` acquisition named "mutex" / "guard" at `line`.
    fn mutex_acq(lock: VarId, guard: VarId, line: u32) -> LockAcquisitionV2 {
        LockAcquisitionV2::new(lock, guard, LockType::Mutex, line, "mutex", "guard")
    }

    #[test]
    fn test_lock_acquisition_v2() {
        let mut vars = TestVars::new();
        let (lock, guard) = vars.pair("lock", "guard");
        let acq = mutex_acq(lock, guard, 10);

        assert_eq!(acq.lock_var, lock);
        assert_eq!(acq.guard_var, guard);
        assert_eq!(acq.lock_type, LockType::Mutex);
        assert_eq!(acq.line, 10);
        assert!(!acq.is_try);

        assert!(acq.with_try().is_try);
    }

    #[test]
    fn test_critical_section_field_tracking() {
        let mut vars = TestVars::new();
        let (lock, guard) = vars.pair("lock", "guard");
        let mut cs = CriticalSectionV2::new(mutex_acq(lock, guard, 10));

        let accesses = [
            ("counter", AccessKind::Read, 11),
            ("counter", AccessKind::Write, 12),
            ("name", AccessKind::Read, 13),
        ];
        for &(field, kind, line) in accesses.iter() {
            cs.add_field_access(FieldAccessV2::new(field, kind, line));
        }

        assert_eq!(cs.unique_fields().len(), 2);
        assert_eq!(cs.field_access_kind("counter"), Some(AccessKind::ReadWrite));
        assert_eq!(cs.field_access_kind("name"), Some(AccessKind::Read));
    }

    #[test]
    fn test_critical_section_is_read_only() {
        let mut vars = TestVars::new();
        let (lock, guard) = vars.pair("lock", "guard");
        let mut cs = CriticalSectionV2::new(mutex_acq(lock, guard, 10));

        cs.add_field_access(FieldAccessV2::new("a", AccessKind::Read, 11));
        cs.add_field_access(FieldAccessV2::new("b", AccessKind::Read, 12));
        assert!(cs.is_read_only());

        cs.add_field_access(FieldAccessV2::new("c", AccessKind::Write, 13));
        assert!(!cs.is_read_only());
    }

    #[test]
    fn test_lock_tracker_lifecycle() {
        let mut tracker = LockTrackerV2::new();
        let mut vars = TestVars::new();
        let (lock, guard) = vars.pair("lock", "guard");

        tracker.acquire(mutex_acq(lock, guard, 10));
        assert!(tracker.is_active(guard));
        assert_eq!(tracker.acquisitions().len(), 1);
        assert_eq!(tracker.active_count(), 1);

        tracker.record_field_access(guard, "counter", AccessKind::Write, 11);
        tracker.release(guard, 15);

        assert!(!tracker.is_active(guard));
        assert_eq!(tracker.critical_sections().len(), 1);
        assert_eq!(tracker.completed_count(), 1);
        assert_eq!(tracker.active_count(), 0);

        let section = &tracker.critical_sections()[0];
        assert_eq!(section.start_line, 10);
        assert_eq!(section.end_line, Some(15));
        assert_eq!(section.field_accesses.len(), 1);
    }

    #[test]
    fn test_lock_tracker_mark_expensive_await() {
        let mut tracker = LockTrackerV2::new();
        let mut vars = TestVars::new();
        let (lock, guard) = vars.pair("lock", "guard");

        tracker.acquire(LockAcquisitionV2::new(
            lock,
            guard,
            LockType::TokioMutex,
            10,
            "mutex",
            "guard",
        ));
        tracker.mark_expensive(guard);
        tracker.mark_await(guard, 12);
        tracker.release(guard, 15);

        let section = &tracker.critical_sections()[0];
        assert!(section.contains_expensive_ops);
        assert!(section.contains_await);
    }

    #[test]
    fn test_lock_tracker_clear() {
        let mut tracker = LockTrackerV2::new();
        let mut vars = TestVars::new();
        let (lock, guard) = vars.pair("lock", "guard");

        tracker.acquire(mutex_acq(lock, guard, 10));
        tracker.release(guard, 15);
        assert_eq!(tracker.completed_count(), 1);

        tracker.clear();
        assert_eq!(tracker.completed_count(), 0);
        assert_eq!(tracker.acquisitions().len(), 0);
    }

    #[test]
    fn test_flush_active_sections() {
        let mut tracker = LockTrackerV2::new();
        let mut vars = TestVars::new();
        let (lock1, guard1) = vars.pair("lock1", "guard1");
        let (lock2, guard2) = vars.pair("lock2", "guard2");

        tracker.acquire(LockAcquisitionV2::new(lock1, guard1, LockType::Mutex, 0, "m1", "g1"));
        tracker.acquire(LockAcquisitionV2::new(
            lock2,
            guard2,
            LockType::RwLockRead,
            0,
            "r1",
            "g2",
        ));
        assert_eq!(tracker.active_count(), 2);
        assert_eq!(tracker.completed_count(), 0);

        tracker.flush_active_sections();

        assert_eq!(tracker.active_count(), 0);
        assert_eq!(tracker.completed_count(), 2);
        // Flushed sections keep an unknown end.
        assert!(tracker
            .critical_sections()
            .iter()
            .all(|section| section.end_line.is_none()));
    }

    #[test]
    fn test_flush_preserves_acquisitions() {
        let mut tracker = LockTrackerV2::new();
        let mut vars = TestVars::new();
        let (lock, guard) = vars.pair("lock", "guard");

        tracker.acquire(mutex_acq(lock, guard, 0));
        tracker.flush_active_sections();

        assert_eq!(tracker.acquisitions().len(), 1);
        assert_eq!(tracker.completed_count(), 1);
    }
}