1use crate::ir::{IrFunction, IrStatement, BorrowKind, OwnershipState};
13use crate::parser::HeaderCache;
14use std::collections::{HashMap, HashSet};
15use crate::debug_println;
16
/// A pointer to `pointee` that was stored into `container` (e.g. `vec.push_back(&x)`).
///
/// Scope values are nesting depths as tracked by `RaiiTracker::current_scope`.
/// `RaiiTracker::exit_scope` reports the record when the pointee's scope ends while the
/// container lives in a shallower scope.
#[derive(Debug, Clone)]
pub struct ContainerBorrow {
    /// Name of the container variable that received the pointer.
    pub container: String,
    /// Name of the variable whose address was stored.
    pub pointee: String,
    /// Scope depth of the container (`record_container_store` uses 0 if it is unregistered).
    pub container_scope: usize,
    /// Scope depth of the pointee (`record_container_store` uses 0 if it is unregistered).
    pub pointee_scope: usize,
    /// Source line of the store, used only in diagnostics.
    pub line: usize,
}
32
/// An iterator obtained from a container (e.g. `it = vec.begin()`).
///
/// `RaiiTracker::exit_scope` reports the record when the container's scope ends while
/// the iterator belongs to a shallower scope.
#[derive(Debug, Clone)]
pub struct IteratorBorrow {
    /// Name of the iterator variable.
    pub iterator: String,
    /// Name of the container the iterator points into.
    pub container: String,
    /// Scope depth attributed to the iterator.
    pub iterator_scope: usize,
    /// Scope depth of the container.
    pub container_scope: usize,
    /// Source line where the iterator was created, used only in diagnostics.
    pub line: usize,
}
47
/// A lambda together with the variables it captures by reference.
#[derive(Debug, Clone)]
pub struct LambdaCapture {
    /// Name recorded for the lambda; `mark_lambda_escaped` matches against this.
    /// NOTE(review): `process_raii_statement` currently always records the placeholder
    /// `"_lambda"`, so escape marking only matches a returned value of that exact name.
    pub lambda_var: String,
    /// Names of the variables captured by reference.
    pub ref_captures: Vec<String>,
    /// Scope depth that was current when the lambda was recorded.
    pub lambda_scope: usize,
    /// Set by `mark_lambda_escaped`; only escaped lambdas are checked on scope exit.
    pub has_escaped: bool,
    /// Source line of the lambda, used only in diagnostics.
    pub line: usize,
}
62
/// Lifecycle state of a tracked heap allocation.
#[derive(Debug, Clone, PartialEq)]
pub enum AllocationState {
    /// Recorded by `record_allocation` and not yet released.
    Allocated,
    /// Released by `record_deallocation`; a further release is reported as a double free.
    Freed,
}
71
/// A heap allocation assigned to a variable, tracked for double-free / use-after-free checks.
#[derive(Debug, Clone)]
pub struct HeapAllocation {
    /// Variable that holds the allocation.
    pub variable: String,
    /// Whether the allocation is still live or has been freed.
    pub state: AllocationState,
    /// Line at which the allocation was recorded.
    pub allocation_line: usize,
    /// Line of the (first) release, `None` while still allocated.
    pub free_line: Option<usize>,
}
80
/// A reference bound to a field of an object (e.g. `int& r = obj.field;`).
///
/// `RaiiTracker::exit_scope` reports the record when the object's scope ends while the
/// reference lives in a shallower scope.
#[derive(Debug, Clone)]
pub struct MemberBorrow {
    /// Name of the reference variable.
    pub reference: String,
    /// Name of the object whose field is borrowed.
    pub object: String,
    /// Name of the borrowed field.
    pub field: String,
    /// Scope depth of the reference.
    pub reference_scope: usize,
    /// Scope depth of the object.
    pub object_scope: usize,
    /// Source line of the borrow, used only in diagnostics.
    pub line: usize,
}
98
/// Per-function state for the RAII / lifetime checks.
///
/// Scopes are identified by nesting depth (`current_scope`), which `enter_scope` /
/// `exit_scope` increment and decrement; sibling blocks therefore share a depth number.
#[derive(Debug)]
pub struct RaiiTracker {
    /// Pointers to locals stored into containers.
    pub container_borrows: Vec<ContainerBorrow>,
    /// Iterators obtained from containers.
    pub iterator_borrows: Vec<IteratorBorrow>,
    /// Lambdas with by-reference captures.
    pub lambda_captures: Vec<LambdaCapture>,
    /// References bound to object fields.
    pub member_borrows: Vec<MemberBorrow>,
    /// Heap allocations keyed by the variable that holds them.
    pub heap_allocations: HashMap<String, HeapAllocation>,
    /// Only initialized (empty) in `new`; NOTE(review): not otherwise used in this file.
    pub user_defined_raii_types: HashSet<String>,
    /// Current scope nesting depth.
    pub current_scope: usize,
    /// Declared scope depth of each registered variable.
    pub variable_scopes: HashMap<String, usize>,
    /// Variables whose type name looked like a container when registered.
    pub container_variables: HashSet<String>,
    /// Variables registered with an iterator-like type or created as iterators.
    pub iterator_variables: HashSet<String>,
}
123
124impl RaiiTracker {
125 pub fn new() -> Self {
126 Self {
127 container_borrows: Vec::new(),
128 iterator_borrows: Vec::new(),
129 lambda_captures: Vec::new(),
130 member_borrows: Vec::new(),
131 heap_allocations: HashMap::new(),
132 user_defined_raii_types: HashSet::new(),
133 current_scope: 0,
134 variable_scopes: HashMap::new(),
135 container_variables: HashSet::new(),
136 iterator_variables: HashSet::new(),
137 }
138 }
139
140 pub fn is_container_type(type_name: &str) -> bool {
142 type_name.contains("vector") ||
143 type_name.contains("Vector") ||
144 type_name.contains("Vec<") ||
145 type_name.contains("list") ||
146 type_name.contains("deque") ||
147 type_name.contains("set") ||
148 type_name.contains("map") ||
149 type_name.contains("unordered_") ||
150 type_name.contains("array<") ||
151 type_name.contains("span<")
152 }
153
154 pub fn is_iterator_type(type_name: &str) -> bool {
156 type_name.contains("iterator") ||
157 type_name.contains("Iterator") ||
158 type_name.ends_with("::iterator") ||
159 type_name.ends_with("::const_iterator") ||
160 type_name.ends_with("::reverse_iterator")
161 }
162
163 pub fn is_container_store_method(method_name: &str) -> bool {
165 method_name == "push_back" ||
166 method_name == "push_front" ||
167 method_name == "insert" ||
168 method_name == "emplace" ||
169 method_name == "emplace_back" ||
170 method_name == "emplace_front" ||
171 method_name == "assign"
172 }
173
174 pub fn is_iterator_returning_method(method_name: &str) -> bool {
176 method_name == "begin" ||
177 method_name == "end" ||
178 method_name == "cbegin" ||
179 method_name == "cend" ||
180 method_name == "rbegin" ||
181 method_name == "rend" ||
182 method_name == "find" ||
183 method_name == "lower_bound" ||
184 method_name == "upper_bound"
185 }
186
187 pub fn register_variable(&mut self, name: &str, type_name: &str, scope: usize) {
189 self.variable_scopes.insert(name.to_string(), scope);
190
191 if Self::is_container_type(type_name) {
192 self.container_variables.insert(name.to_string());
193 }
194
195 if Self::is_iterator_type(type_name) {
196 self.iterator_variables.insert(name.to_string());
197 }
198 }
199
200 pub fn record_container_store(&mut self, container: &str, pointee: &str, line: usize) {
202 let container_scope = *self.variable_scopes.get(container).unwrap_or(&0);
203 let pointee_scope = *self.variable_scopes.get(pointee).unwrap_or(&0);
204
205 self.container_borrows.push(ContainerBorrow {
206 container: container.to_string(),
207 pointee: pointee.to_string(),
208 container_scope,
209 pointee_scope,
210 line,
211 });
212 }
213
214 pub fn record_iterator_creation(&mut self, iterator: &str, container: &str, line: usize) {
216 let iterator_scope = self.current_scope;
217 let container_scope = *self.variable_scopes.get(container).unwrap_or(&0);
218
219 self.iterator_borrows.push(IteratorBorrow {
220 iterator: iterator.to_string(),
221 container: container.to_string(),
222 iterator_scope,
223 container_scope,
224 line,
225 });
226
227 self.iterator_variables.insert(iterator.to_string());
228 }
229
230 pub fn record_lambda(&mut self, lambda_var: &str, ref_captures: Vec<String>, line: usize) {
232 self.lambda_captures.push(LambdaCapture {
233 lambda_var: lambda_var.to_string(),
234 ref_captures,
235 lambda_scope: self.current_scope,
236 has_escaped: false,
237 line,
238 });
239 }
240
241 pub fn mark_lambda_escaped(&mut self, lambda_var: &str) {
243 for capture in &mut self.lambda_captures {
244 if capture.lambda_var == lambda_var {
245 capture.has_escaped = true;
246 }
247 }
248 }
249
250 pub fn record_member_borrow(&mut self, reference: &str, object: &str, field: &str, line: usize) {
253 let reference_scope = *self.variable_scopes.get(reference).unwrap_or(&self.current_scope);
255 let object_scope = *self.variable_scopes.get(object).unwrap_or(&0);
256
257 self.member_borrows.push(MemberBorrow {
258 reference: reference.to_string(),
259 object: object.to_string(),
260 field: field.to_string(),
261 reference_scope,
262 object_scope,
263 line,
264 });
265 }
266
267 pub fn record_allocation(&mut self, var: &str, line: usize) {
269 self.heap_allocations.insert(var.to_string(), HeapAllocation {
270 variable: var.to_string(),
271 state: AllocationState::Allocated,
272 allocation_line: line,
273 free_line: None,
274 });
275 }
276
277 pub fn record_deallocation(&mut self, var: &str, line: usize) -> Option<String> {
279 if let Some(alloc) = self.heap_allocations.get_mut(var) {
280 if alloc.state == AllocationState::Freed {
281 return Some(format!(
283 "Double free: '{}' was already freed at line {}",
284 var, alloc.free_line.unwrap_or(0)
285 ));
286 }
287 alloc.state = AllocationState::Freed;
288 alloc.free_line = Some(line);
289 }
290 None
291 }
292
293 pub fn is_freed(&self, var: &str) -> bool {
295 self.heap_allocations.get(var)
296 .map(|a| a.state == AllocationState::Freed)
297 .unwrap_or(false)
298 }
299
300 pub fn enter_scope(&mut self) {
302 self.current_scope += 1;
303 }
304
305 pub fn exit_scope(&mut self) -> Vec<String> {
307 let mut errors = Vec::new();
308 let dying_scope = self.current_scope;
309
310 for borrow in &self.container_borrows {
312 if borrow.pointee_scope == dying_scope && borrow.container_scope < dying_scope {
314 errors.push(format!(
315 "Dangling pointer in container: '{}' stored pointer to '{}' which goes out of scope (stored at line {})",
316 borrow.container, borrow.pointee, borrow.line
317 ));
318 }
319 }
320
321 for borrow in &self.iterator_borrows {
323 if borrow.container_scope == dying_scope && borrow.iterator_scope < dying_scope {
325 errors.push(format!(
326 "Iterator outlives container: '{}' borrows from '{}' which goes out of scope (created at line {})",
327 borrow.iterator, borrow.container, borrow.line
328 ));
329 }
330 }
331
332 for capture in &self.lambda_captures {
334 if capture.has_escaped {
335 for ref_var in &capture.ref_captures {
336 if self.variable_scopes.get(ref_var) == Some(&dying_scope) {
337 errors.push(format!(
338 "Lambda escape: lambda '{}' captures '{}' by reference, but '{}' goes out of scope (lambda at line {})",
339 capture.lambda_var, ref_var, ref_var, capture.line
340 ));
341 }
342 }
343 }
344 }
345
346 for borrow in &self.member_borrows {
348 if borrow.object_scope == dying_scope && borrow.reference_scope < dying_scope {
350 errors.push(format!(
351 "Dangling member reference: '{}' references '{}.{}' but '{}' goes out of scope (borrowed at line {})",
352 borrow.reference, borrow.object, borrow.field, borrow.object, borrow.line
353 ));
354 }
355 }
356
357 self.container_borrows.retain(|b| b.pointee_scope != dying_scope || b.container_scope >= dying_scope);
360 self.iterator_borrows.retain(|b| b.container_scope != dying_scope || b.iterator_scope >= dying_scope);
362 self.member_borrows.retain(|b| b.reference_scope != dying_scope && b.object_scope != dying_scope);
364
365 if self.current_scope > 0 {
367 self.current_scope -= 1;
368 }
369 errors
370 }
371}
372
373pub fn check_raii_issues(
375 function: &IrFunction,
376 _header_cache: &HeaderCache,
377) -> Result<Vec<String>, String> {
378 let mut errors = Vec::new();
379 let mut tracker = RaiiTracker::new();
380
381 for (name, info) in &function.variables {
383 let type_name = format!("{:?}", info.ty);
384 tracker.register_variable(name, &type_name, info.scope_level);
385 }
386
387 for node_idx in function.cfg.node_indices() {
389 let block = &function.cfg[node_idx];
390 for stmt in &block.statements {
391 let stmt_errors = process_raii_statement(stmt, &mut tracker, function);
392 errors.extend(stmt_errors);
393 }
394 }
395
396 Ok(errors)
397}
398
399fn process_raii_statement(
401 stmt: &IrStatement,
402 tracker: &mut RaiiTracker,
403 function: &IrFunction,
404) -> Vec<String> {
405 let mut errors = Vec::new();
406
407 match stmt {
408 IrStatement::EnterScope => {
409 tracker.enter_scope();
410 }
411
412 IrStatement::ExitScope => {
413 let scope_errors = tracker.exit_scope();
414 errors.extend(scope_errors);
415 }
416
417 IrStatement::CallExpr { func, args, result } => {
418 let method_name = func.split("::").last().unwrap_or(func);
420
421 if RaiiTracker::is_container_store_method(method_name) {
422 if let Some(container) = extract_receiver(func) {
425 for arg in args {
427 if arg.starts_with('&') {
429 let pointee = arg.trim_start_matches('&');
430 tracker.record_container_store(&container, pointee, 0);
431 }
432 }
433 }
434 }
435
436 if RaiiTracker::is_iterator_returning_method(method_name) {
438 if let (Some(result_var), Some(container)) = (result, extract_receiver(func)) {
439 tracker.record_iterator_creation(result_var, &container, 0);
440 }
441 }
442
443 if func == "operator new" || func.contains("::operator new") {
445 if let Some(result_var) = result {
446 tracker.record_allocation(result_var, 0);
447 }
448 }
449
450 if func == "operator delete" || func.contains("::operator delete") {
451 if let Some(arg) = args.first() {
452 if let Some(err) = tracker.record_deallocation(arg, 0) {
453 errors.push(err);
454 }
455 }
456 }
457 }
458
459 IrStatement::UseVariable { var, operation } => {
460 if tracker.is_freed(var) {
462 errors.push(format!(
463 "Use after free: variable '{}' has been freed (operation: {})",
464 var, operation
465 ));
466 }
467 }
468
469 IrStatement::Return { value } => {
470 if let Some(val) = value {
472 tracker.mark_lambda_escaped(val);
473 }
474 }
475
476 IrStatement::LambdaCapture { captures } => {
477 let ref_captures: Vec<String> = captures
478 .iter()
479 .filter(|c| c.is_ref)
480 .map(|c| c.name.clone())
481 .collect();
482
483 if !ref_captures.is_empty() {
484 tracker.record_lambda("_lambda", ref_captures, 0);
487 }
488 }
489
490 IrStatement::BorrowField { object, field, to, .. } => {
492 tracker.record_member_borrow(to, object, field, 0);
494 }
495
496 _ => {}
497 }
498
499 errors
500}
501
/// Returns the receiver of a call target, i.e. everything before its last `.`.
///
/// Targets without a `.` have no receiver and yield `None`. This covers qualified free
/// functions such as `std::sort` and plain names.
fn extract_receiver(func: &str) -> Option<String> {
    func.rfind('.').map(|dot| func[..dot].to_owned())
}
518
519pub fn has_user_defined_destructor(type_name: &str) -> bool {
522 if is_primitive_or_builtin(type_name) {
527 return false;
528 }
529
530 if type_name.starts_with("const ") ||
532 type_name.ends_with("&") ||
533 type_name.ends_with("*") {
534 return false;
535 }
536
537 !type_name.contains("::") ||
540 type_name.contains("std::") ||
541 type_name.starts_with("class ") ||
542 type_name.starts_with("struct ")
543}
544
/// Returns `true` when `type_name` spells a C/C++ fundamental type or a standard
/// integer/size alias (`int`, `unsigned long long`, `std::size_t`, `char32_t`, ...).
///
/// Matching rules:
/// * anything from the first `<` onwards is ignored,
/// * a leading `std::` is stripped, so `std::size_t` and `size_t` are treated alike,
/// * multi-word spellings are accepted when every word is itself a primitive keyword
///   (`unsigned int`, `long double`, `signed char`),
/// * an empty name is not primitive.
///
/// Qualifiers (`const`, `volatile`), pointers and references are not recognised here.
fn is_primitive_or_builtin(type_name: &str) -> bool {
    const PRIMITIVES: &[&str] = &[
        "int", "char", "bool", "float", "double", "void",
        "long", "short", "unsigned", "signed",
        "wchar_t", "char8_t", "char16_t", "char32_t",
        "int8_t", "int16_t", "int32_t", "int64_t",
        "uint8_t", "uint16_t", "uint32_t", "uint64_t",
        "intptr_t", "uintptr_t", "intmax_t", "uintmax_t",
        "size_t", "ptrdiff_t", "nullptr_t",
    ];

    let base = type_name.split('<').next().unwrap_or(type_name).trim();
    let base = base.strip_prefix("std::").unwrap_or(base);

    let mut words = base.split_whitespace().peekable();
    words.peek().is_some() && words.all(|word| PRIMITIVES.contains(&word))
}
557
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_container_type() {
        for name in &["std::vector<int>", "std::map<int, int>", "std::unordered_map<int, int>"] {
            assert!(RaiiTracker::is_container_type(name), "{} should be a container", name);
        }
        for name in &["int", "std::string"] {
            assert!(!RaiiTracker::is_container_type(name), "{} should not be a container", name);
        }
    }

    #[test]
    fn test_is_iterator_type() {
        let iterators = ["std::vector<int>::iterator", "std::map<int,int>::const_iterator"];
        assert!(iterators.iter().all(|name| RaiiTracker::is_iterator_type(name)));
        assert!(!RaiiTracker::is_iterator_type("int*"));
    }

    #[test]
    fn test_container_borrow_detection() {
        let mut tracker = RaiiTracker::new();
        // The container outlives the pointee: vec at depth 0, x at depth 1.
        tracker.register_variable("vec", "std::vector<int*>", 0);
        tracker.register_variable("x", "int", 1);
        tracker.current_scope = 1;
        tracker.record_container_store("vec", "x", 10);

        match tracker.exit_scope().as_slice() {
            [only] => assert!(only.contains("Dangling pointer"), "unexpected message: {}", only),
            other => panic!("expected exactly one diagnostic, got {:?}", other),
        }
    }

    #[test]
    fn test_iterator_outlives_container() {
        let mut tracker = RaiiTracker::new();
        tracker.current_scope = 1;
        tracker.variable_scopes.insert("vec".to_string(), 1);
        tracker.container_variables.insert("vec".to_string());
        tracker.iterator_borrows.push(IteratorBorrow {
            iterator: "it".to_string(),
            container: "vec".to_string(),
            iterator_scope: 0,
            container_scope: 1,
            line: 10,
        });

        match tracker.exit_scope().as_slice() {
            [only] => assert!(
                only.contains("Iterator outlives container"),
                "unexpected message: {}",
                only
            ),
            other => panic!("expected exactly one diagnostic, got {:?}", other),
        }
    }

    #[test]
    fn test_double_free_detection() {
        let mut tracker = RaiiTracker::new();
        tracker.record_allocation("ptr", 10);

        assert_eq!(tracker.record_deallocation("ptr", 20), None);

        let second = tracker
            .record_deallocation("ptr", 30)
            .expect("the second delete must be reported");
        assert!(second.contains("Double free"));
    }

    #[test]
    fn test_use_after_free() {
        let mut tracker = RaiiTracker::new();
        tracker.record_allocation("ptr", 10);
        let _ = tracker.record_deallocation("ptr", 20);

        assert!(tracker.is_freed("ptr"));
    }
}