1use std::cmp::Reverse;
8
9use super::{AccessKind, CriticalSectionV2, LockSuggestion, LockTrackerV2, LockType};
10
11pub struct LockGranularityAnalyzerV2<'a> {
19 tracker: &'a LockTrackerV2,
20}
21
22impl<'a> LockGranularityAnalyzerV2<'a> {
23 pub fn new(tracker: &'a LockTrackerV2) -> Self {
25 Self { tracker }
26 }
27
28 pub fn analyze(&self) -> Vec<LockSuggestion> {
30 let mut suggestions = Vec::new();
31
32 for cs in self.tracker.critical_sections() {
33 suggestions.extend(self.analyze_critical_section(cs));
34 }
35
36 suggestions.sort_by_key(|b| Reverse(b.severity()));
38 suggestions
39 }
40
41 fn analyze_critical_section(&self, cs: &CriticalSectionV2) -> Vec<LockSuggestion> {
43 let mut suggestions = Vec::new();
44
45 if cs.contains_await {
47 suggestions.push(LockSuggestion::LockAcrossAwait {
48 guard_name: cs.acquisition.guard_name.clone(),
49 lock_line: cs.acquisition.line,
50 await_line: cs.start_line, });
52 }
53
54 suggestions.extend(self.check_atomic_opportunity(cs));
56
57 if let Some(split) = self.check_split_opportunity(cs) {
59 suggestions.push(split);
60 }
61
62 if let Some(rwlock) = self.check_rwlock_opportunity(cs) {
64 suggestions.push(rwlock);
65 }
66
67 if let Some(reduce) = self.check_scope_reduction(cs) {
69 suggestions.push(reduce);
70 }
71
72 suggestions
73 }
74
75 fn check_atomic_opportunity(&self, cs: &CriticalSectionV2) -> Vec<LockSuggestion> {
77 let mut suggestions = Vec::new();
78 let unique_fields = cs.unique_fields();
79
80 if unique_fields.len() == 1 {
82 let field = unique_fields[0];
83 let access_kind = cs.field_access_kind(field);
84
85 let suggested_type = self.suggest_atomic_type(field, access_kind);
87
88 if let Some(atomic_type) = suggested_type {
89 suggestions.push(LockSuggestion::UseAtomic {
90 field: field.to_string(),
91 current_type: None,
92 suggested_type: atomic_type,
93 line: cs.start_line,
94 });
95 }
96 }
97
98 suggestions
99 }
100
101 fn suggest_atomic_type(&self, field: &str, _access_kind: Option<AccessKind>) -> Option<String> {
103 let field_lower = field.to_lowercase();
104
105 if field_lower.contains("count")
107 || field_lower.contains("counter")
108 || field_lower.contains("num")
109 || field_lower.contains("total")
110 || field_lower.contains("size")
111 || field_lower.contains("len")
112 {
113 return Some("AtomicUsize".to_string());
114 }
115
116 if field_lower.contains("flag")
118 || field_lower.contains("enabled")
119 || field_lower.contains("active")
120 || field_lower.contains("ready")
121 || field_lower.contains("done")
122 || field_lower.contains("is_")
123 {
124 return Some("AtomicBool".to_string());
125 }
126
127 if field_lower.contains("id")
129 || field_lower.contains("index")
130 || field_lower.contains("seq")
131 {
132 return Some("AtomicU64".to_string());
133 }
134
135 None
136 }
137
138 fn check_split_opportunity(&self, cs: &CriticalSectionV2) -> Option<LockSuggestion> {
140 let unique_fields = cs.unique_fields();
141
142 if unique_fields.len() < 2 {
144 return None;
145 }
146
147 let mut suggested_splits = Vec::new();
149
150 for field in unique_fields {
151 let access_kind = cs.field_access_kind(field);
152 let wrapper = match access_kind {
153 Some(AccessKind::Read) => "Arc<RwLock<_>>".to_string(),
154 Some(AccessKind::Write) | Some(AccessKind::ReadWrite) => {
155 "Arc<Mutex<_>>".to_string()
156 }
157 None => continue,
158 };
159 suggested_splits.push((field.to_string(), wrapper));
160 }
161
162 if suggested_splits.len() >= 2 {
163 Some(LockSuggestion::SplitLock {
164 lock_name: cs.acquisition.lock_name.clone(),
165 suggested_splits,
166 line: cs.acquisition.line,
167 })
168 } else {
169 None
170 }
171 }
172
173 fn check_rwlock_opportunity(&self, cs: &CriticalSectionV2) -> Option<LockSuggestion> {
175 if !matches!(
177 cs.acquisition.lock_type,
178 LockType::Mutex | LockType::ParkingLotMutex | LockType::TokioMutex
179 ) {
180 return None;
181 }
182
183 let mut read_count = 0;
184 let mut write_count = 0;
185
186 for access in &cs.field_accesses {
187 match access.access_kind {
188 AccessKind::Read => read_count += 1,
189 AccessKind::Write => write_count += 1,
190 AccessKind::ReadWrite => {
191 read_count += 1;
192 write_count += 1;
193 }
194 }
195 }
196
197 if read_count > write_count * 2 && read_count >= 3 {
199 Some(LockSuggestion::UseRwLock {
200 lock_name: cs.acquisition.lock_name.clone(),
201 read_count,
202 write_count,
203 line: cs.acquisition.line,
204 })
205 } else {
206 None
207 }
208 }
209
210 fn check_scope_reduction(&self, cs: &CriticalSectionV2) -> Option<LockSuggestion> {
212 let end_line = cs.end_line?;
213 let span = end_line.saturating_sub(cs.start_line);
214
215 if span > 5 && cs.contains_expensive_ops {
217 let first_access = cs.field_accesses.iter().map(|a| a.line).min()?;
219 let last_access = cs.field_accesses.iter().map(|a| a.line).max()?;
220
221 if first_access > cs.start_line + 2 || end_line > last_access + 2 {
223 return Some(LockSuggestion::ReduceScope {
224 guard_name: cs.acquisition.guard_name.clone(),
225 current_span: (cs.start_line, end_line),
226 suggested_span: (first_access.saturating_sub(1), last_access + 1),
227 reason: "lock held across non-critical operations".to_string(),
228 });
229 }
230 }
231
232 None
233 }
234
235 pub fn stats(&self) -> LockStatsV2 {
237 let sections = self.tracker.critical_sections();
238
239 let mut mutex_count = 0;
240 let mut rwlock_count = 0;
241 let mut refcell_count = 0;
242 let mut total_field_accesses = 0;
243 let mut max_cs_span = 0u32;
244
245 for cs in sections {
246 match cs.acquisition.lock_type {
247 LockType::Mutex | LockType::ParkingLotMutex | LockType::TokioMutex => {
248 mutex_count += 1
249 }
250 LockType::RwLockRead
251 | LockType::RwLockWrite
252 | LockType::ParkingLotRwLock
253 | LockType::TokioRwLock => rwlock_count += 1,
254 LockType::RefCell | LockType::RefCellMut => refcell_count += 1,
255 }
256
257 total_field_accesses += cs.field_accesses.len();
258
259 if let Some(span) = cs.span() {
260 max_cs_span = max_cs_span.max(span);
261 }
262 }
263
264 LockStatsV2 {
265 total_locks: sections.len(),
266 mutex_count,
267 rwlock_count,
268 refcell_count,
269 total_field_accesses,
270 max_cs_span,
271 }
272 }
273
274 pub fn tracker(&self) -> &LockTrackerV2 {
276 self.tracker
277 }
278}
279
/// Summary counters produced by `LockGranularityAnalyzerV2::stats`.
///
/// `Default` yields all-zero statistics.
#[derive(Debug, Clone, Default)]
pub struct LockStatsV2 {
    /// Number of critical sections reported by the tracker.
    pub total_locks: usize,
    /// Sections acquired through `Mutex`, `ParkingLotMutex` or `TokioMutex`.
    pub mutex_count: usize,
    /// Sections acquired through `RwLockRead`, `RwLockWrite`,
    /// `ParkingLotRwLock` or `TokioRwLock`.
    pub rwlock_count: usize,
    /// Sections acquired through `RefCell` or `RefCellMut`.
    pub refcell_count: usize,
    /// Sum of recorded field accesses across all sections.
    pub total_field_accesses: usize,
    /// Largest span reported by any section; `0` when no section has a span.
    pub max_cs_span: u32,
}
296
#[cfg(test)]
mod tests {
    use super::super::{LockAcquisitionV2, VarSymbolMapping};
    use super::*;
    use crate::symbol::SymbolId;
    use crate::VarId;
    use slotmap::SlotMap;

    /// Fixture that mints `VarId`s by registering freshly inserted symbols.
    struct TestVars {
        symbols: SlotMap<SymbolId, &'static str>,
        mapping: VarSymbolMapping,
    }

    impl TestVars {
        /// Starts with no symbols and an empty mapping.
        fn new() -> Self {
            let symbols = SlotMap::with_key();
            let mapping = VarSymbolMapping::new();
            Self { symbols, mapping }
        }

        /// Inserts `name` as a new symbol and returns the variable registered for it.
        fn var(&mut self, name: &'static str) -> VarId {
            let symbol = self.symbols.insert(name);
            self.mapping.register(symbol)
        }
    }

    /// Registers a lock variable followed by a guard variable, in that order.
    fn lock_and_guard(
        vars: &mut TestVars,
        lock: &'static str,
        guard: &'static str,
    ) -> (VarId, VarId) {
        let lock_var = vars.var(lock);
        let guard_var = vars.var(guard);
        (lock_var, guard_var)
    }

    /// A single write to a counter-named field should yield a `UseAtomic` hint.
    #[test]
    fn test_atomic_suggestion_counter() {
        let mut tracker = LockTrackerV2::new();
        let mut vars = TestVars::new();
        let (lock_var, guard_var) = lock_and_guard(&mut vars, "lock", "guard");

        let acquisition =
            LockAcquisitionV2::new(lock_var, guard_var, LockType::Mutex, 10, "mutex", "guard");
        tracker.acquire(acquisition);
        tracker.record_field_access(guard_var, "counter", AccessKind::Write, 11);
        tracker.release(guard_var, 15);

        let suggestions = LockGranularityAnalyzerV2::new(&tracker).analyze();

        let suggests_atomic_counter = suggestions
            .iter()
            .any(|s| matches!(s, LockSuggestion::UseAtomic { field, .. } if field == "counter"));
        assert!(suggests_atomic_counter);
    }

    /// Four reads and one write under a mutex should yield a `UseRwLock` hint.
    #[test]
    fn test_rwlock_suggestion() {
        let mut tracker = LockTrackerV2::new();
        let mut vars = TestVars::new();
        let (lock_var, guard_var) = lock_and_guard(&mut vars, "lock", "guard");

        let acquisition =
            LockAcquisitionV2::new(lock_var, guard_var, LockType::Mutex, 10, "cache", "guard");
        tracker.acquire(acquisition);

        // Reads on lines 11 through 14, then one write on line 15.
        for line in 11..=14 {
            tracker.record_field_access(guard_var, "data", AccessKind::Read, line);
        }
        tracker.record_field_access(guard_var, "data", AccessKind::Write, 15);

        tracker.release(guard_var, 20);

        let suggestions = LockGranularityAnalyzerV2::new(&tracker).analyze();

        let suggests_rwlock = suggestions.iter().any(|s| {
            matches!(s, LockSuggestion::UseRwLock { read_count, write_count, .. }
                if *read_count == 4 && *write_count == 1)
        });
        assert!(suggests_rwlock);
    }

    /// Three distinct fields under one lock should yield a three-way split.
    #[test]
    fn test_split_lock_suggestion() {
        let mut tracker = LockTrackerV2::new();
        let mut vars = TestVars::new();
        let (lock_var, guard_var) = lock_and_guard(&mut vars, "lock", "guard");

        let acquisition =
            LockAcquisitionV2::new(lock_var, guard_var, LockType::Mutex, 10, "state", "guard");
        tracker.acquire(acquisition);

        let accesses = [
            ("counter", AccessKind::Write, 11),
            ("name", AccessKind::Read, 12),
            ("config", AccessKind::Read, 13),
        ];
        for (field, kind, line) in accesses {
            tracker.record_field_access(guard_var, field, kind, line);
        }

        tracker.release(guard_var, 20);

        let suggestions = LockGranularityAnalyzerV2::new(&tracker).analyze();

        let suggests_three_way_split = suggestions.iter().any(|s| {
            matches!(s, LockSuggestion::SplitLock { suggested_splits, .. }
                if suggested_splits.len() == 3)
        });
        assert!(suggests_three_way_split);
    }

    /// One mutex section and one rwlock section should be counted separately.
    #[test]
    fn test_lock_stats() {
        let mut tracker = LockTrackerV2::new();
        let mut vars = TestVars::new();

        // All four variables are registered before any acquisition is recorded.
        let (lock1, guard1) = lock_and_guard(&mut vars, "lock1", "guard1");
        let (lock2, guard2) = lock_and_guard(&mut vars, "lock2", "guard2");

        let first = LockAcquisitionV2::new(lock1, guard1, LockType::Mutex, 10, "m1", "g1");
        tracker.acquire(first);
        tracker.record_field_access(guard1, "field1", AccessKind::Write, 11);
        tracker.release(guard1, 15);

        let second = LockAcquisitionV2::new(lock2, guard2, LockType::RwLockRead, 20, "r1", "g2");
        tracker.acquire(second);
        tracker.record_field_access(guard2, "field2", AccessKind::Read, 21);
        tracker.release(guard2, 25);

        let stats = LockGranularityAnalyzerV2::new(&tracker).stats();

        assert_eq!(stats.total_locks, 2);
        assert_eq!(stats.mutex_count, 1);
        assert_eq!(stats.rwlock_count, 1);
        assert_eq!(stats.total_field_accesses, 2);
    }
}