rusty_cpp/analysis/
lambda_capture_safety.rs

//! Lambda capture safety checking for @safe code, based on escape analysis.
//!
//! In @safe code:
//! - Reference captures (`[&]`, `[&x]`) are ALLOWED as long as the lambda does not escape.
//! - Reference captures in a lambda that ESCAPES are FORBIDDEN: they can create dangling references.
//! - Copy captures (`[x]`, `[=]`) are ALWAYS ALLOWED: copy semantics are safe.
//! - Move captures (`[x = std::move(y)]`) are ALWAYS ALLOWED: ownership transfer is safe.
//! - Capturing `this` is FORBIDDEN: `this` is a raw pointer that can dangle.
//!
//! A lambda "escapes" when:
//! - it is returned from the function,
//! - it is stored in a variable or container that outlives the captured variables, or
//! - it is passed to a function that takes ownership of it (stores it).
14
15use crate::parser::{Function, Statement, Expression};
16use crate::parser::ast_visitor::LambdaCaptureKind;
17use crate::parser::safety_annotations::SafetyMode;
18use crate::debug_println;
19use std::collections::{HashMap, HashSet};
20
/// Bookkeeping for the lambdas seen while walking one function body, and
/// for which of them have been marked as escaping.
#[derive(Debug)]
struct LambdaContext {
    /// Lambda variable name -> names of the variables it captures by reference (`[&x]`).
    lambda_ref_captures: HashMap<String, Vec<String>>,
    /// Lambda variable name -> whether the lambda uses the default reference capture `[&]`.
    lambda_has_default_ref: HashMap<String, bool>,
    /// Names that have been marked as escaping. A name may be recorded here
    /// without ever having been registered as a lambda.
    escaped_lambdas: HashSet<String>,
    /// Current scope nesting depth (0 = outermost).
    scope_depth: usize,
    /// Lambda variable name -> scope depth at the time it was registered.
    lambda_scopes: HashMap<String, usize>,
    /// Variable name -> scope depth at the time it was registered.
    variable_scopes: HashMap<String, usize>,
}

impl LambdaContext {
    /// Creates an empty context positioned at scope depth 0.
    fn new() -> Self {
        Self {
            lambda_ref_captures: HashMap::new(),
            lambda_has_default_ref: HashMap::new(),
            escaped_lambdas: HashSet::new(),
            scope_depth: 0,
            lambda_scopes: HashMap::new(),
            variable_scopes: HashMap::new(),
        }
    }

    /// Records entry into a nested scope.
    fn enter_scope(&mut self) {
        self.scope_depth += 1;
    }

    /// Records exit from a scope; an unbalanced exit at depth 0 is ignored
    /// rather than underflowing.
    fn exit_scope(&mut self) {
        self.scope_depth = self.scope_depth.saturating_sub(1);
    }

    /// Remembers that `name` was declared at the current scope depth.
    fn register_variable(&mut self, name: &str) {
        self.variable_scopes.insert(name.to_string(), self.scope_depth);
    }

    /// Registers (or re-registers, replacing earlier data) the lambda stored
    /// under `name`, together with its by-reference captures and whether it
    /// uses `[&]`.
    fn register_lambda(&mut self, name: &str, ref_captures: Vec<String>, has_default_ref: bool) {
        self.lambda_ref_captures.insert(name.to_string(), ref_captures);
        self.lambda_has_default_ref.insert(name.to_string(), has_default_ref);
        self.lambda_scopes.insert(name.to_string(), self.scope_depth);
    }

    /// Marks `name` as escaping the function.
    fn mark_escaped(&mut self, name: &str) {
        self.escaped_lambdas.insert(name.to_string());
    }

    /// Returns `(name, ref_captures, has_default_ref)` for every escaped name
    /// that is a registered lambda with at least one reference capture
    /// (explicit or default).
    ///
    /// The result is sorted by lambda name: `escaped_lambdas` is a `HashSet`,
    /// whose iteration order is unspecified, and callers turn this list
    /// directly into user-visible diagnostics that should be stable from run
    /// to run.
    fn get_escaped_lambdas_with_ref_captures(&self) -> Vec<(String, Vec<String>, bool)> {
        let mut escaped: Vec<(String, Vec<String>, bool)> = self
            .escaped_lambdas
            .iter()
            .filter_map(|name| {
                // Escaped names that were never registered as lambdas are skipped.
                let captures = self.lambda_ref_captures.get(name)?;
                let has_default_ref = self
                    .lambda_has_default_ref
                    .get(name)
                    .copied()
                    .unwrap_or(false);
                if captures.is_empty() && !has_default_ref {
                    None
                } else {
                    Some((name.clone(), captures.clone(), has_default_ref))
                }
            })
            .collect();
        // Names are unique (they come from a set), so an unstable sort is deterministic.
        escaped.sort_unstable_by(|a, b| a.0.cmp(&b.0));
        escaped
    }
}
92
93/// Check a parsed function for lambda capture safety violations
94pub fn check_lambda_capture_safety(
95    function: &Function,
96    function_safety: SafetyMode,
97) -> Vec<String> {
98    let mut errors = Vec::new();
99
100    // Only check @safe functions
101    if function_safety != SafetyMode::Safe {
102        debug_println!("DEBUG LAMBDA: Skipping function '{}' (not @safe)", function.name);
103        return errors;
104    }
105
106    debug_println!("DEBUG LAMBDA: Checking function '{}' for lambda capture safety", function.name);
107
108    // Track if we're inside an @unsafe block
109    let mut unsafe_depth = 0;
110    let mut lambda_context = LambdaContext::new();
111
112    // First pass: collect all lambda definitions and track escapes
113    collect_lambdas_and_escapes(&function.body, &function.name, &mut lambda_context, &mut unsafe_depth);
114
115    // Check for 'this' captures (always forbidden)
116    check_this_captures_errors(&function.body, &function.name, &mut errors, &mut 0);
117
118    // Report errors for escaped lambdas with reference captures
119    for (lambda_name, ref_captures, has_default_ref) in lambda_context.get_escaped_lambdas_with_ref_captures() {
120        if has_default_ref {
121            errors.push(format!(
122                "Reference capture in @safe code: Lambda '{}' escapes but uses default reference capture [&] which can create dangling references - use copy capture [=] instead",
123                lambda_name
124            ));
125        } else {
126            for capture in ref_captures {
127                errors.push(format!(
128                    "Reference capture in @safe code: Lambda '{}' escapes but captures '{}' by reference ([&{}]) which can create dangling references - use copy capture [{}] instead",
129                    lambda_name, capture, capture, capture
130                ));
131            }
132        }
133    }
134
135    errors
136}
137
138fn collect_lambdas_and_escapes(
139    statements: &[Statement],
140    function_name: &str,
141    ctx: &mut LambdaContext,
142    unsafe_depth: &mut usize,
143) {
144    for stmt in statements {
145        match stmt {
146            Statement::EnterUnsafe => {
147                *unsafe_depth += 1;
148            }
149            Statement::ExitUnsafe => {
150                if *unsafe_depth > 0 {
151                    *unsafe_depth -= 1;
152                }
153            }
154            Statement::EnterScope => {
155                ctx.enter_scope();
156            }
157            Statement::ExitScope => {
158                ctx.exit_scope();
159            }
160            Statement::VariableDecl(var) => {
161                ctx.register_variable(&var.name);
162            }
163            Statement::Assignment { lhs, rhs, .. } => {
164                if *unsafe_depth == 0 {
165                    // Check if RHS is a lambda
166                    if let Some((ref_captures, has_default_ref)) = extract_lambda_captures(rhs) {
167                        // Extract variable name from lhs expression
168                        if let Some(lhs_name) = extract_variable_name(lhs) {
169                            ctx.register_lambda(&lhs_name, ref_captures, has_default_ref);
170                            debug_println!("DEBUG LAMBDA: Registered lambda '{}' in function '{}'", lhs_name, function_name);
171                        }
172                    }
173                }
174            }
175            Statement::Return(Some(expr)) => {
176                // Returning a lambda = escape
177                if *unsafe_depth == 0 {
178                    if let Some(var_name) = extract_variable_name(expr) {
179                        ctx.mark_escaped(&var_name);
180                        debug_println!("DEBUG LAMBDA: Lambda '{}' escapes via return in '{}'", var_name, function_name);
181                    }
182                    // Check if returning a lambda expression directly
183                    if let Some((ref_captures, has_default_ref)) = extract_lambda_captures(expr) {
184                        // Anonymous lambda being returned - this is an escape
185                        let lambda_name = format!("_anon_lambda_{}", statements.len());
186                        ctx.register_lambda(&lambda_name, ref_captures, has_default_ref);
187                        ctx.mark_escaped(&lambda_name);
188                        debug_println!("DEBUG LAMBDA: Anonymous lambda escapes via return in '{}'", function_name);
189                    }
190                }
191            }
192            Statement::ExpressionStatement { expr, .. } => {
193                if *unsafe_depth == 0 {
194                    // Check for function calls that might store the lambda
195                    check_for_escape_via_call(expr, ctx, function_name);
196                }
197            }
198            Statement::If { then_branch, else_branch, .. } => {
199                collect_lambdas_and_escapes(then_branch, function_name, ctx, unsafe_depth);
200                if let Some(else_stmts) = else_branch {
201                    collect_lambdas_and_escapes(else_stmts, function_name, ctx, unsafe_depth);
202                }
203            }
204            Statement::Block(inner_stmts) => {
205                collect_lambdas_and_escapes(inner_stmts, function_name, ctx, unsafe_depth);
206            }
207            _ => {}
208        }
209    }
210}
211
212fn check_this_captures_errors(
213    statements: &[Statement],
214    function_name: &str,
215    errors: &mut Vec<String>,
216    unsafe_depth: &mut usize,
217) {
218    for stmt in statements {
219        match stmt {
220            Statement::EnterUnsafe => {
221                *unsafe_depth += 1;
222            }
223            Statement::ExitUnsafe => {
224                if *unsafe_depth > 0 {
225                    *unsafe_depth -= 1;
226                }
227            }
228            Statement::Assignment { rhs, location, .. } => {
229                if *unsafe_depth == 0 {
230                    check_expression_for_this_capture(rhs, function_name, location, errors);
231                }
232            }
233            Statement::ExpressionStatement { expr, location } => {
234                if *unsafe_depth == 0 {
235                    check_expression_for_this_capture(expr, function_name, location, errors);
236                }
237            }
238            Statement::Return(Some(expr)) => {
239                if *unsafe_depth == 0 {
240                    let default_location = crate::parser::ast_visitor::SourceLocation {
241                        file: "unknown".to_string(),
242                        line: 0,
243                        column: 0,
244                    };
245                    check_expression_for_this_capture(expr, function_name, &default_location, errors);
246                }
247            }
248            Statement::If { then_branch, else_branch, .. } => {
249                check_this_captures_errors(then_branch, function_name, errors, unsafe_depth);
250                if let Some(else_stmts) = else_branch {
251                    check_this_captures_errors(else_stmts, function_name, errors, unsafe_depth);
252                }
253            }
254            Statement::Block(inner_stmts) => {
255                check_this_captures_errors(inner_stmts, function_name, errors, unsafe_depth);
256            }
257            _ => {}
258        }
259    }
260}
261
262fn check_expression_for_this_capture(
263    expr: &Expression,
264    _function_name: &str,
265    location: &crate::parser::ast_visitor::SourceLocation,
266    errors: &mut Vec<String>,
267) {
268    match expr {
269        Expression::Lambda { captures } => {
270            for capture in captures {
271                if matches!(capture, LambdaCaptureKind::This) {
272                    errors.push(format!(
273                        "Reference capture in @safe code at {}:{}: Capturing 'this' is forbidden in @safe code - 'this' is a raw pointer that can dangle",
274                        location.file, location.line
275                    ));
276                }
277            }
278        }
279        Expression::FunctionCall { args, .. } => {
280            for arg in args {
281                check_expression_for_this_capture(arg, _function_name, location, errors);
282            }
283        }
284        Expression::Move { inner, .. } |
285        Expression::Dereference(inner) |
286        Expression::AddressOf(inner) => {
287            check_expression_for_this_capture(inner, _function_name, location, errors);
288        }
289        Expression::BinaryOp { left, right, .. } => {
290            check_expression_for_this_capture(left, _function_name, location, errors);
291            check_expression_for_this_capture(right, _function_name, location, errors);
292        }
293        Expression::MemberAccess { object, .. } => {
294            check_expression_for_this_capture(object, _function_name, location, errors);
295        }
296        _ => {}
297    }
298}
299
300fn check_for_escape_via_call(
301    expr: &Expression,
302    ctx: &mut LambdaContext,
303    function_name: &str,
304) {
305    if let Expression::FunctionCall { name, args, .. } = expr {
306        // Check if passing a lambda to a function that stores it
307        // For now, we consider push_back, emplace_back, insert, etc. as escaping
308        let storing_methods = ["push_back", "emplace_back", "push_front", "emplace_front",
309                              "insert", "emplace", "assign", "store"];
310
311        let method_name = name.split("::").last().unwrap_or(name);
312        let method_name = method_name.split('.').last().unwrap_or(method_name);
313
314        if storing_methods.iter().any(|&m| method_name.contains(m)) {
315            for arg in args {
316                if let Some(var_name) = extract_variable_name(arg) {
317                    ctx.mark_escaped(&var_name);
318                    debug_println!("DEBUG LAMBDA: Lambda '{}' potentially escapes via {} in '{}'",
319                        var_name, name, function_name);
320                }
321            }
322        }
323
324        // Recursively check nested calls
325        for arg in args {
326            check_for_escape_via_call(arg, ctx, function_name);
327        }
328    }
329}
330
331fn extract_lambda_captures(expr: &Expression) -> Option<(Vec<String>, bool)> {
332    match expr {
333        Expression::Lambda { captures } => {
334            let mut ref_captures = Vec::new();
335            let mut has_default_ref = false;
336
337            for capture in captures {
338                match capture {
339                    LambdaCaptureKind::DefaultRef => {
340                        has_default_ref = true;
341                    }
342                    LambdaCaptureKind::ByRef(var_name) => {
343                        ref_captures.push(var_name.clone());
344                    }
345                    _ => {}
346                }
347            }
348
349            Some((ref_captures, has_default_ref))
350        }
351        _ => None,
352    }
353}
354
355fn extract_variable_name(expr: &Expression) -> Option<String> {
356    match expr {
357        Expression::Variable(name) => Some(name.clone()),
358        Expression::Move { inner, .. } => extract_variable_name(inner),
359        _ => None,
360    }
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::ast_visitor::SourceLocation;

    /// Fixed source location for tests that need one.
    fn make_location() -> SourceLocation {
        SourceLocation {
            file: "test.cpp".to_string(),
            line: 1,
            column: 1,
        }
    }

    /// Builds a lambda expression with the given capture list.
    fn lambda_with(captures: Vec<LambdaCaptureKind>) -> Expression {
        Expression::Lambda { captures }
    }

    /// Runs `extract_lambda_captures` on a lambda with the given capture list.
    fn captures_of(captures: Vec<LambdaCaptureKind>) -> (Vec<String>, bool) {
        extract_lambda_captures(&lambda_with(captures)).unwrap()
    }

    /// Context holding a single lambda `lambda1` with the given reference captures.
    fn context_with_lambda1(ref_captures: Vec<String>) -> LambdaContext {
        let mut ctx = LambdaContext::new();
        ctx.register_lambda("lambda1", ref_captures, false);
        ctx
    }

    #[test]
    fn test_this_capture_in_safe_is_error() {
        let lambda = lambda_with(vec![LambdaCaptureKind::This]);
        let mut errors = Vec::new();

        check_expression_for_this_capture(&lambda, "test", &make_location(), &mut errors);

        assert_eq!(errors.len(), 1);
        assert!(errors[0].contains("this"));
    }

    #[test]
    fn test_copy_capture_in_safe_is_ok() {
        // Copy captures never contribute reference captures.
        let (ref_captures, has_default_ref) =
            captures_of(vec![LambdaCaptureKind::ByCopy("x".to_string())]);

        assert!(ref_captures.is_empty());
        assert!(!has_default_ref);
    }

    #[test]
    fn test_default_copy_capture_in_safe_is_ok() {
        let (ref_captures, has_default_ref) = captures_of(vec![LambdaCaptureKind::DefaultCopy]);

        assert!(ref_captures.is_empty());
        assert!(!has_default_ref);
    }

    #[test]
    fn test_init_capture_in_safe_is_ok() {
        let init_capture = LambdaCaptureKind::Init {
            name: "y".to_string(),
            is_move: true,
        };

        let (ref_captures, has_default_ref) = captures_of(vec![init_capture]);

        assert!(ref_captures.is_empty());
        assert!(!has_default_ref);
    }

    #[test]
    fn test_ref_capture_extraction() {
        let (ref_captures, has_default_ref) =
            captures_of(vec![LambdaCaptureKind::ByRef("x".to_string())]);

        assert_eq!(ref_captures, vec!["x"]);
        assert!(!has_default_ref);
    }

    #[test]
    fn test_default_ref_capture_extraction() {
        let (ref_captures, has_default_ref) = captures_of(vec![LambdaCaptureKind::DefaultRef]);

        assert!(ref_captures.is_empty());
        assert!(has_default_ref);
    }

    #[test]
    fn test_lambda_context_escape_tracking() {
        let mut ctx = context_with_lambda1(vec!["x".to_string()]);

        // Registered but not escaped: nothing to report yet.
        assert!(ctx.get_escaped_lambdas_with_ref_captures().is_empty());

        ctx.mark_escaped("lambda1");

        // Once escaped, the lambda and its capture are reported.
        let escaped = ctx.get_escaped_lambdas_with_ref_captures();
        assert_eq!(escaped.len(), 1);
        assert_eq!(escaped[0].0, "lambda1");
        assert_eq!(escaped[0].1, vec!["x"]);
    }

    #[test]
    fn test_lambda_context_no_escape_no_error() {
        // A reference-capturing lambda that never escapes is fine.
        let ctx = context_with_lambda1(vec!["x".to_string()]);

        let escaped = ctx.get_escaped_lambdas_with_ref_captures();
        assert!(escaped.is_empty());
    }

    #[test]
    fn test_lambda_context_escape_no_ref_capture() {
        // An escaping lambda without reference captures is fine as well.
        let mut ctx = context_with_lambda1(vec![]);
        ctx.mark_escaped("lambda1");

        let escaped = ctx.get_escaped_lambdas_with_ref_captures();
        assert!(escaped.is_empty());
    }
}