1use crate::parser::{Function, Statement, Expression};
16use crate::parser::ast_visitor::LambdaCaptureKind;
17use crate::parser::safety_annotations::SafetyMode;
18use crate::debug_println;
19use std::collections::{HashMap, HashSet};
20
/// Per-function bookkeeping for the lambda escape analysis.
///
/// Lambdas are keyed by the variable name they were assigned to, or by a
/// synthetic `_anon_lambda_N` name for lambda literals that escape without
/// being bound to a variable.
#[derive(Debug)]
struct LambdaContext {
    /// Names captured explicitly by reference (`[&x]`) for each registered lambda.
    lambda_ref_captures: HashMap<String, Vec<String>>,
    /// Whether each registered lambda uses the default reference capture `[&]`.
    lambda_has_default_ref: HashMap<String, bool>,
    /// Names observed escaping the function. A name may be recorded here
    /// without ever being registered as a lambda; such names are skipped when
    /// reporting.
    escaped_lambdas: HashSet<String>,
    /// Current lexical nesting depth, driven by `enter_scope` / `exit_scope`.
    scope_depth: usize,
    /// Scope depth at which each lambda was registered.
    /// NOTE(review): recorded but not read by the code visible here.
    lambda_scopes: HashMap<String, usize>,
    /// Scope depth at which each variable was declared.
    /// NOTE(review): recorded but not read by the code visible here.
    variable_scopes: HashMap<String, usize>,
}

impl LambdaContext {
    /// Creates an empty context at scope depth 0.
    fn new() -> Self {
        Self {
            lambda_ref_captures: HashMap::new(),
            lambda_has_default_ref: HashMap::new(),
            escaped_lambdas: HashSet::new(),
            scope_depth: 0,
            lambda_scopes: HashMap::new(),
            variable_scopes: HashMap::new(),
        }
    }

    /// Enters a nested lexical scope.
    fn enter_scope(&mut self) {
        self.scope_depth += 1;
    }

    /// Leaves the current lexical scope; unbalanced exits stay at depth 0.
    fn exit_scope(&mut self) {
        self.scope_depth = self.scope_depth.saturating_sub(1);
    }

    /// Records the current scope depth for `name` (a redeclaration overwrites it).
    fn register_variable(&mut self, name: &str) {
        self.variable_scopes.insert(name.to_string(), self.scope_depth);
    }

    /// Records (or replaces) the capture information for the lambda `name`.
    fn register_lambda(&mut self, name: &str, ref_captures: Vec<String>, has_default_ref: bool) {
        self.lambda_ref_captures.insert(name.to_string(), ref_captures);
        self.lambda_has_default_ref.insert(name.to_string(), has_default_ref);
        self.lambda_scopes.insert(name.to_string(), self.scope_depth);
    }

    /// Records that `name` escapes the function.
    fn mark_escaped(&mut self, name: &str) {
        self.escaped_lambdas.insert(name.to_string());
    }

    /// Returns each escaped, registered lambda that captures anything by
    /// reference, as `(name, explicit_ref_captures, has_default_ref)`.
    ///
    /// The result is sorted by lambda name. `escaped_lambdas` is a `HashSet`
    /// whose iteration order varies between runs, so without sorting the
    /// diagnostics built from this list would be reported in an unstable order.
    fn get_escaped_lambdas_with_ref_captures(&self) -> Vec<(String, Vec<String>, bool)> {
        let mut escaped: Vec<(String, Vec<String>, bool)> = self
            .escaped_lambdas
            .iter()
            .filter_map(|name| {
                // Names that were never registered as lambdas are ignored.
                let captures = self.lambda_ref_captures.get(name)?;
                let has_default_ref = self
                    .lambda_has_default_ref
                    .get(name)
                    .copied()
                    .unwrap_or(false);
                if captures.is_empty() && !has_default_ref {
                    None
                } else {
                    Some((name.clone(), captures.clone(), has_default_ref))
                }
            })
            .collect();
        // Names come from a set, so they are unique and an unstable sort is enough.
        escaped.sort_unstable_by(|a, b| a.0.cmp(&b.0));
        escaped
    }
}
92
93pub fn check_lambda_capture_safety(
95 function: &Function,
96 function_safety: SafetyMode,
97) -> Vec<String> {
98 let mut errors = Vec::new();
99
100 if function_safety != SafetyMode::Safe {
102 debug_println!("DEBUG LAMBDA: Skipping function '{}' (not @safe)", function.name);
103 return errors;
104 }
105
106 debug_println!("DEBUG LAMBDA: Checking function '{}' for lambda capture safety", function.name);
107
108 let mut unsafe_depth = 0;
110 let mut lambda_context = LambdaContext::new();
111
112 collect_lambdas_and_escapes(&function.body, &function.name, &mut lambda_context, &mut unsafe_depth);
114
115 check_this_captures_errors(&function.body, &function.name, &mut errors, &mut 0);
117
118 for (lambda_name, ref_captures, has_default_ref) in lambda_context.get_escaped_lambdas_with_ref_captures() {
120 if has_default_ref {
121 errors.push(format!(
122 "Reference capture in @safe code: Lambda '{}' escapes but uses default reference capture [&] which can create dangling references - use copy capture [=] instead",
123 lambda_name
124 ));
125 } else {
126 for capture in ref_captures {
127 errors.push(format!(
128 "Reference capture in @safe code: Lambda '{}' escapes but captures '{}' by reference ([&{}]) which can create dangling references - use copy capture [{}] instead",
129 lambda_name, capture, capture, capture
130 ));
131 }
132 }
133 }
134
135 errors
136}
137
138fn collect_lambdas_and_escapes(
139 statements: &[Statement],
140 function_name: &str,
141 ctx: &mut LambdaContext,
142 unsafe_depth: &mut usize,
143) {
144 for stmt in statements {
145 match stmt {
146 Statement::EnterUnsafe => {
147 *unsafe_depth += 1;
148 }
149 Statement::ExitUnsafe => {
150 if *unsafe_depth > 0 {
151 *unsafe_depth -= 1;
152 }
153 }
154 Statement::EnterScope => {
155 ctx.enter_scope();
156 }
157 Statement::ExitScope => {
158 ctx.exit_scope();
159 }
160 Statement::VariableDecl(var) => {
161 ctx.register_variable(&var.name);
162 }
163 Statement::Assignment { lhs, rhs, .. } => {
164 if *unsafe_depth == 0 {
165 if let Some((ref_captures, has_default_ref)) = extract_lambda_captures(rhs) {
167 if let Some(lhs_name) = extract_variable_name(lhs) {
169 ctx.register_lambda(&lhs_name, ref_captures, has_default_ref);
170 debug_println!("DEBUG LAMBDA: Registered lambda '{}' in function '{}'", lhs_name, function_name);
171 }
172 }
173 }
174 }
175 Statement::Return(Some(expr)) => {
176 if *unsafe_depth == 0 {
178 if let Some(var_name) = extract_variable_name(expr) {
179 ctx.mark_escaped(&var_name);
180 debug_println!("DEBUG LAMBDA: Lambda '{}' escapes via return in '{}'", var_name, function_name);
181 }
182 if let Some((ref_captures, has_default_ref)) = extract_lambda_captures(expr) {
184 let lambda_name = format!("_anon_lambda_{}", statements.len());
186 ctx.register_lambda(&lambda_name, ref_captures, has_default_ref);
187 ctx.mark_escaped(&lambda_name);
188 debug_println!("DEBUG LAMBDA: Anonymous lambda escapes via return in '{}'", function_name);
189 }
190 }
191 }
192 Statement::ExpressionStatement { expr, .. } => {
193 if *unsafe_depth == 0 {
194 check_for_escape_via_call(expr, ctx, function_name);
196 }
197 }
198 Statement::If { then_branch, else_branch, .. } => {
199 collect_lambdas_and_escapes(then_branch, function_name, ctx, unsafe_depth);
200 if let Some(else_stmts) = else_branch {
201 collect_lambdas_and_escapes(else_stmts, function_name, ctx, unsafe_depth);
202 }
203 }
204 Statement::Block(inner_stmts) => {
205 collect_lambdas_and_escapes(inner_stmts, function_name, ctx, unsafe_depth);
206 }
207 _ => {}
208 }
209 }
210}
211
212fn check_this_captures_errors(
213 statements: &[Statement],
214 function_name: &str,
215 errors: &mut Vec<String>,
216 unsafe_depth: &mut usize,
217) {
218 for stmt in statements {
219 match stmt {
220 Statement::EnterUnsafe => {
221 *unsafe_depth += 1;
222 }
223 Statement::ExitUnsafe => {
224 if *unsafe_depth > 0 {
225 *unsafe_depth -= 1;
226 }
227 }
228 Statement::Assignment { rhs, location, .. } => {
229 if *unsafe_depth == 0 {
230 check_expression_for_this_capture(rhs, function_name, location, errors);
231 }
232 }
233 Statement::ExpressionStatement { expr, location } => {
234 if *unsafe_depth == 0 {
235 check_expression_for_this_capture(expr, function_name, location, errors);
236 }
237 }
238 Statement::Return(Some(expr)) => {
239 if *unsafe_depth == 0 {
240 let default_location = crate::parser::ast_visitor::SourceLocation {
241 file: "unknown".to_string(),
242 line: 0,
243 column: 0,
244 };
245 check_expression_for_this_capture(expr, function_name, &default_location, errors);
246 }
247 }
248 Statement::If { then_branch, else_branch, .. } => {
249 check_this_captures_errors(then_branch, function_name, errors, unsafe_depth);
250 if let Some(else_stmts) = else_branch {
251 check_this_captures_errors(else_stmts, function_name, errors, unsafe_depth);
252 }
253 }
254 Statement::Block(inner_stmts) => {
255 check_this_captures_errors(inner_stmts, function_name, errors, unsafe_depth);
256 }
257 _ => {}
258 }
259 }
260}
261
262fn check_expression_for_this_capture(
263 expr: &Expression,
264 _function_name: &str,
265 location: &crate::parser::ast_visitor::SourceLocation,
266 errors: &mut Vec<String>,
267) {
268 match expr {
269 Expression::Lambda { captures } => {
270 for capture in captures {
271 if matches!(capture, LambdaCaptureKind::This) {
272 errors.push(format!(
273 "Reference capture in @safe code at {}:{}: Capturing 'this' is forbidden in @safe code - 'this' is a raw pointer that can dangle",
274 location.file, location.line
275 ));
276 }
277 }
278 }
279 Expression::FunctionCall { args, .. } => {
280 for arg in args {
281 check_expression_for_this_capture(arg, _function_name, location, errors);
282 }
283 }
284 Expression::Move { inner, .. } |
285 Expression::Dereference(inner) |
286 Expression::AddressOf(inner) => {
287 check_expression_for_this_capture(inner, _function_name, location, errors);
288 }
289 Expression::BinaryOp { left, right, .. } => {
290 check_expression_for_this_capture(left, _function_name, location, errors);
291 check_expression_for_this_capture(right, _function_name, location, errors);
292 }
293 Expression::MemberAccess { object, .. } => {
294 check_expression_for_this_capture(object, _function_name, location, errors);
295 }
296 _ => {}
297 }
298}
299
300fn check_for_escape_via_call(
301 expr: &Expression,
302 ctx: &mut LambdaContext,
303 function_name: &str,
304) {
305 if let Expression::FunctionCall { name, args, .. } = expr {
306 let storing_methods = ["push_back", "emplace_back", "push_front", "emplace_front",
309 "insert", "emplace", "assign", "store"];
310
311 let method_name = name.split("::").last().unwrap_or(name);
312 let method_name = method_name.split('.').last().unwrap_or(method_name);
313
314 if storing_methods.iter().any(|&m| method_name.contains(m)) {
315 for arg in args {
316 if let Some(var_name) = extract_variable_name(arg) {
317 ctx.mark_escaped(&var_name);
318 debug_println!("DEBUG LAMBDA: Lambda '{}' potentially escapes via {} in '{}'",
319 var_name, name, function_name);
320 }
321 }
322 }
323
324 for arg in args {
326 check_for_escape_via_call(arg, ctx, function_name);
327 }
328 }
329}
330
331fn extract_lambda_captures(expr: &Expression) -> Option<(Vec<String>, bool)> {
332 match expr {
333 Expression::Lambda { captures } => {
334 let mut ref_captures = Vec::new();
335 let mut has_default_ref = false;
336
337 for capture in captures {
338 match capture {
339 LambdaCaptureKind::DefaultRef => {
340 has_default_ref = true;
341 }
342 LambdaCaptureKind::ByRef(var_name) => {
343 ref_captures.push(var_name.clone());
344 }
345 _ => {}
346 }
347 }
348
349 Some((ref_captures, has_default_ref))
350 }
351 _ => None,
352 }
353}
354
355fn extract_variable_name(expr: &Expression) -> Option<String> {
356 match expr {
357 Expression::Variable(name) => Some(name.clone()),
358 Expression::Move { inner, .. } => extract_variable_name(inner),
359 _ => None,
360 }
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::ast_visitor::SourceLocation;

    /// Source location attached to diagnostics produced by these tests.
    fn make_location() -> SourceLocation {
        SourceLocation {
            file: "test.cpp".to_string(),
            line: 1,
            column: 1,
        }
    }

    /// Builds a lambda expression with the given capture list.
    fn lambda_with(captures: Vec<LambdaCaptureKind>) -> Expression {
        Expression::Lambda { captures }
    }

    /// Runs `extract_lambda_captures` on a lambda built from `captures`.
    fn captures_of(captures: Vec<LambdaCaptureKind>) -> (Vec<String>, bool) {
        extract_lambda_captures(&lambda_with(captures)).unwrap()
    }

    /// A context holding `lambda1`, which captures `x` by reference.
    fn context_with_ref_lambda() -> LambdaContext {
        let mut ctx = LambdaContext::new();
        ctx.register_lambda("lambda1", vec!["x".to_string()], false);
        ctx
    }

    #[test]
    fn test_this_capture_in_safe_is_error() {
        let mut errors = Vec::new();
        check_expression_for_this_capture(
            &lambda_with(vec![LambdaCaptureKind::This]),
            "test",
            &make_location(),
            &mut errors,
        );

        assert_eq!(errors.len(), 1);
        assert!(errors[0].contains("this"));
    }

    #[test]
    fn test_copy_capture_in_safe_is_ok() {
        let (by_ref, default_ref) = captures_of(vec![LambdaCaptureKind::ByCopy("x".to_string())]);
        assert!(by_ref.is_empty());
        assert!(!default_ref);
    }

    #[test]
    fn test_default_copy_capture_in_safe_is_ok() {
        let (by_ref, default_ref) = captures_of(vec![LambdaCaptureKind::DefaultCopy]);
        assert!(by_ref.is_empty());
        assert!(!default_ref);
    }

    #[test]
    fn test_init_capture_in_safe_is_ok() {
        let (by_ref, default_ref) = captures_of(vec![LambdaCaptureKind::Init {
            name: "y".to_string(),
            is_move: true,
        }]);
        assert!(by_ref.is_empty());
        assert!(!default_ref);
    }

    #[test]
    fn test_ref_capture_extraction() {
        let (by_ref, default_ref) = captures_of(vec![LambdaCaptureKind::ByRef("x".to_string())]);
        assert_eq!(by_ref, vec!["x"]);
        assert!(!default_ref);
    }

    #[test]
    fn test_default_ref_capture_extraction() {
        let (by_ref, default_ref) = captures_of(vec![LambdaCaptureKind::DefaultRef]);
        assert!(by_ref.is_empty());
        assert!(default_ref);
    }

    #[test]
    fn test_lambda_context_escape_tracking() {
        let mut ctx = context_with_ref_lambda();

        // Registered but not escaped yet: nothing to report.
        assert!(ctx.get_escaped_lambdas_with_ref_captures().is_empty());

        ctx.mark_escaped("lambda1");

        let escaped = ctx.get_escaped_lambdas_with_ref_captures();
        assert_eq!(escaped.len(), 1);
        assert_eq!(escaped[0].0, "lambda1");
        assert_eq!(escaped[0].1, vec!["x"]);
    }

    #[test]
    fn test_lambda_context_no_escape_no_error() {
        let ctx = context_with_ref_lambda();

        // A lambda that never escapes is never reported.
        assert!(ctx.get_escaped_lambdas_with_ref_captures().is_empty());
    }

    #[test]
    fn test_lambda_context_escape_no_ref_capture() {
        let mut ctx = LambdaContext::new();
        ctx.register_lambda("lambda1", vec![], false);
        ctx.mark_escaped("lambda1");

        // Escaping is harmless when nothing is captured by reference.
        assert!(ctx.get_escaped_lambdas_with_ref_captures().is_empty());
    }
}