1use crate::parser::{Function, Statement, Expression};
2use crate::parser::safety_annotations::{SafetyContext, SafetyMode};
3use crate::parser::external_annotations::ExternalAnnotations;
4use std::collections::HashSet;
5
6pub fn check_unsafe_propagation(
13 function: &Function,
14 safety_context: &SafetyContext,
15 known_safe_functions: &HashSet<String>,
16) -> Vec<String> {
17 check_unsafe_propagation_with_external(function, safety_context, known_safe_functions, None)
18}
19
20pub fn check_unsafe_propagation_with_external(
22 function: &Function,
23 safety_context: &SafetyContext,
24 known_safe_functions: &HashSet<String>,
25 external_annotations: Option<&ExternalAnnotations>,
26) -> Vec<String> {
27 let mut errors = Vec::new();
28 let mut unsafe_depth = 0;
29
30 let callable_params = get_callable_parameters(&function.parameters, &function.template_parameters);
33
34 for stmt in &function.body {
36 match stmt {
38 Statement::EnterUnsafe => {
39 unsafe_depth += 1;
40 continue;
41 }
42 Statement::ExitUnsafe => {
43 if unsafe_depth > 0 {
44 unsafe_depth -= 1;
45 }
46 continue;
47 }
48 _ => {}
49 }
50
51 let in_unsafe_scope = unsafe_depth > 0;
53
54 if let Some(error) = check_statement_for_unsafe_calls_with_external(
55 stmt, safety_context, known_safe_functions, external_annotations,
56 &function.template_parameters, &callable_params, in_unsafe_scope
57 ) {
58 errors.push(format!("In function '{}': {}", function.name, error));
59 }
60 }
61
62 errors
63}
64
65fn get_callable_parameters(parameters: &[crate::parser::Variable], template_params: &[String]) -> HashSet<String> {
68 let mut callable_params = HashSet::new();
69
70 for param in parameters {
71 let type_name = ¶m.type_name;
74
75 for template_param in template_params {
76 if type_contains_template_param(type_name, template_param) {
79 callable_params.insert(param.name.clone());
80 break;
81 }
82 }
83 }
84
85 callable_params
86}
87
/// Returns `true` when `template_param` appears as a whole identifier token in `type_name`.
///
/// `type_name` is split on every character that is neither alphanumeric nor `_`. This means
/// qualifiers, references, pointers and template brackets do not hide the parameter:
/// `"const T&"`, `"T*"` and `"std::function<T>"` all mention `T`. Identifiers that merely
/// contain it, such as `"Tconst"` or `"constT"`, do not. An empty `template_param` never
/// matches.
fn type_contains_template_param(type_name: &str, template_param: &str) -> bool {
    // Splitting yields empty tokens between adjacent separators; an empty name would match them.
    if template_param.is_empty() {
        return false;
    }

    // Compare whole tokens rather than stripping `const`/`&`/`*` by substring replacement,
    // which turned identifiers like `Tconst` into `T` and produced false matches.
    type_name
        .split(|c: char| !c.is_alphanumeric() && c != '_')
        .any(|token| token == template_param)
}
108
109fn check_statement_for_unsafe_calls(
110 stmt: &Statement,
111 safety_context: &SafetyContext,
112 known_safe_functions: &HashSet<String>,
113) -> Option<String> {
114 check_statement_for_unsafe_calls_with_external(stmt, safety_context, known_safe_functions, None, &[], &HashSet::new(), false)
115}
116
/// Decides whether a callee `name` should be treated as a template parameter rather than a
/// concrete function.
///
/// Returns `true` when `name`:
/// - is one of `template_params`;
/// - is a pack expansion (`Args...`) or a reference form (`F&`, `F&&`) of one of them; or
/// - looks like a type parameter: 1 to 8 bytes long, starts with an uppercase character,
///   and contains only alphanumerics and `_`.
fn is_template_parameter_like(name: &str, template_params: &[String]) -> bool {
    let is_declared =
        |candidate: &str| template_params.iter().any(|param| param.as_str() == candidate);

    // NOTE(review): this heuristic is not gated on `template_params` being non-empty. Short
    // capitalised callees (e.g. `Widget`) are therefore exempted from safety checks even in
    // non-template functions; confirm this is intended.
    let short_capitalised_identifier = (1..=8).contains(&name.len())
        && name.starts_with(char::is_uppercase)
        && name.chars().all(|c| c.is_alphanumeric() || c == '_');

    is_declared(name)
        || (name.ends_with("...") && is_declared(name.trim_end_matches("...").trim()))
        || (name.ends_with('&') && is_declared(name.trim_end_matches('&').trim()))
        || short_capitalised_identifier
}
160
161fn check_statements_with_unsafe_tracking(
163 statements: &[Statement],
164 safety_context: &SafetyContext,
165 known_safe_functions: &HashSet<String>,
166 external_annotations: Option<&ExternalAnnotations>,
167 template_params: &[String],
168 callable_params: &HashSet<String>,
169 initial_unsafe_depth: usize,
170) -> Vec<String> {
171 let mut errors = Vec::new();
172 let mut unsafe_depth = initial_unsafe_depth;
173
174 for stmt in statements {
175 match stmt {
177 Statement::EnterUnsafe => {
178 unsafe_depth += 1;
179 continue;
180 }
181 Statement::ExitUnsafe => {
182 if unsafe_depth > 0 {
183 unsafe_depth -= 1;
184 }
185 continue;
186 }
187 _ => {}
188 }
189
190 let in_unsafe_scope = unsafe_depth > 0;
191
192 if let Some(error) = check_statement_for_unsafe_calls_with_external(
193 stmt, safety_context, known_safe_functions, external_annotations,
194 template_params, callable_params, in_unsafe_scope
195 ) {
196 errors.push(error);
197 }
198 }
199
200 errors
201}
202
203fn check_statement_for_unsafe_calls_with_external(
204 stmt: &Statement,
205 safety_context: &SafetyContext,
206 known_safe_functions: &HashSet<String>,
207 external_annotations: Option<&ExternalAnnotations>,
208 template_params: &[String],
209 callable_params: &HashSet<String>,
210 in_unsafe_scope: bool,
211) -> Option<String> {
212 use crate::parser::Statement;
213
214 if in_unsafe_scope {
216 return None;
217 }
218
219 match stmt {
220 Statement::FunctionCall { name, location, .. } => {
221 if is_template_parameter_like(name, template_params) {
224 return None; }
226
227 if !template_params.is_empty() && name == "unknown" {
230 return None; }
232
233 if name == "operator()" || name.contains("operator()") {
237 return None; }
239
240 if callable_params.contains(name) {
245 return None; }
247 if let Some(simple_name) = name.rsplit("::").next() {
249 if callable_params.contains(simple_name) {
250 return None; }
252 }
253
254 let called_safety = get_called_function_safety(name, safety_context, known_safe_functions, external_annotations);
256
257 match called_safety {
258 SafetyMode::Safe => {
259 }
261 SafetyMode::Unsafe => {
262 return Some(format!(
265 "Calling non-safe function '{}' at line {} requires @unsafe {{ }} block",
266 name, location.line
267 ));
268 }
269 }
270 }
271 Statement::Assignment { rhs, location, .. } => {
272 if let Some(unsafe_func) = find_unsafe_function_call_with_external(rhs, safety_context, known_safe_functions, external_annotations, template_params, callable_params) {
274 return Some(format!(
275 "Calling unsafe function '{}' at line {} requires unsafe context",
276 unsafe_func, location.line
277 ));
278 }
279 }
280 Statement::Return(Some(expr)) => {
281 if let Some(unsafe_func) = find_unsafe_function_call_with_external(expr, safety_context, known_safe_functions, external_annotations, template_params, callable_params) {
283 return Some(format!(
284 "Calling unsafe function '{}' in return statement requires unsafe context",
285 unsafe_func
286 ));
287 }
288 }
289 Statement::If { condition, then_branch, else_branch, location } => {
290 if let Some(unsafe_func) = find_unsafe_function_call_with_external(condition, safety_context, known_safe_functions, external_annotations, template_params, callable_params) {
292 return Some(format!(
293 "Calling unsafe function '{}' in condition at line {} requires unsafe context",
294 unsafe_func, location.line
295 ));
296 }
297
298 let then_errors = check_statements_with_unsafe_tracking(
301 then_branch, safety_context, known_safe_functions, external_annotations,
302 template_params, callable_params, 0
303 );
304 if !then_errors.is_empty() {
305 return Some(then_errors.into_iter().next().unwrap());
306 }
307
308 if let Some(else_stmts) = else_branch {
309 let else_errors = check_statements_with_unsafe_tracking(
310 else_stmts, safety_context, known_safe_functions, external_annotations,
311 template_params, callable_params, 0
312 );
313 if !else_errors.is_empty() {
314 return Some(else_errors.into_iter().next().unwrap());
315 }
316 }
317 }
318 Statement::Block(statements) => {
319 let block_errors = check_statements_with_unsafe_tracking(
321 statements, safety_context, known_safe_functions, external_annotations,
322 template_params, callable_params, 0
323 );
324 if !block_errors.is_empty() {
325 return Some(block_errors.into_iter().next().unwrap());
326 }
327 }
328 _ => {}
329 }
330
331 None
332}
333
334fn find_unsafe_function_call(
335 expr: &Expression,
336 safety_context: &SafetyContext,
337 known_safe_functions: &HashSet<String>,
338) -> Option<String> {
339 find_unsafe_function_call_with_external(expr, safety_context, known_safe_functions, None, &[], &HashSet::new())
340}
341
342fn find_unsafe_function_call_with_external(
343 expr: &Expression,
344 safety_context: &SafetyContext,
345 known_safe_functions: &HashSet<String>,
346 external_annotations: Option<&ExternalAnnotations>,
347 template_params: &[String],
348 callable_params: &HashSet<String>,
349) -> Option<String> {
350 use crate::parser::Expression;
351
352 match expr {
353 Expression::FunctionCall { name, args } => {
354 if is_template_parameter_like(name, template_params) {
357 for arg in args {
360 if let Some(unsafe_func) = find_unsafe_function_call_with_external(arg, safety_context, known_safe_functions, external_annotations, template_params, callable_params) {
361 return Some(unsafe_func);
362 }
363 }
364 return None;
365 }
366
367 if !template_params.is_empty() && name == "unknown" {
370 for arg in args {
372 if let Some(unsafe_func) = find_unsafe_function_call_with_external(arg, safety_context, known_safe_functions, external_annotations, template_params, callable_params) {
373 return Some(unsafe_func);
374 }
375 }
376 return None; }
378
379 if name == "operator()" || name.contains("operator()") {
382 for arg in args {
384 if let Some(unsafe_func) = find_unsafe_function_call_with_external(arg, safety_context, known_safe_functions, external_annotations, template_params, callable_params) {
385 return Some(unsafe_func);
386 }
387 }
388 return None; }
390
391 let is_callable_param = callable_params.contains(name) ||
395 name.rsplit("::").next().map(|s| callable_params.contains(s)).unwrap_or(false);
396 if is_callable_param {
397 for arg in args {
399 if let Some(unsafe_func) = find_unsafe_function_call_with_external(arg, safety_context, known_safe_functions, external_annotations, template_params, callable_params) {
400 return Some(unsafe_func);
401 }
402 }
403 return None; }
405
406 let called_safety = get_called_function_safety(name, safety_context, known_safe_functions, external_annotations);
408
409 match called_safety {
414 SafetyMode::Safe => {
415 }
417 SafetyMode::Unsafe => {
418 return Some(format!("{} (non-safe - use @unsafe block)", name));
420 }
421 }
422
423 for arg in args {
425 if let Some(unsafe_func) = find_unsafe_function_call_with_external(arg, safety_context, known_safe_functions, external_annotations, template_params, callable_params) {
426 return Some(unsafe_func);
427 }
428 }
429 }
430 Expression::BinaryOp { left, right, .. } => {
431 if let Some(unsafe_func) = find_unsafe_function_call_with_external(left, safety_context, known_safe_functions, external_annotations, template_params, callable_params) {
433 return Some(unsafe_func);
434 }
435 if let Some(unsafe_func) = find_unsafe_function_call_with_external(right, safety_context, known_safe_functions, external_annotations, template_params, callable_params) {
436 return Some(unsafe_func);
437 }
438 }
439 Expression::Move { inner, .. } | Expression::Dereference(inner) | Expression::AddressOf(inner) => {
440 if let Some(unsafe_func) = find_unsafe_function_call_with_external(inner, safety_context, known_safe_functions, external_annotations, template_params, callable_params) {
442 return Some(unsafe_func);
443 }
444 }
445 _ => {}
446 }
447
448 None
449}
450
451fn is_function_safe(
452 func_name: &str,
453 safety_context: &SafetyContext,
454 known_safe_functions: &HashSet<String>,
455) -> bool {
456 is_function_safe_with_external(func_name, safety_context, known_safe_functions, None)
457}
458
459fn get_called_function_safety(
461 func_name: &str,
462 safety_context: &SafetyContext,
463 known_safe_functions: &HashSet<String>,
464 external_annotations: Option<&ExternalAnnotations>,
465) -> SafetyMode {
466 let local_safety = safety_context.get_function_safety(func_name);
468 if local_safety != SafetyMode::Unsafe {
469 return local_safety;
470 }
471
472 if known_safe_functions.contains(func_name) {
474 return SafetyMode::Safe;
475 }
476
477 if let Some(annotations) = external_annotations {
479 if let Some(is_safe) = annotations.is_function_safe(func_name) {
480 return if is_safe { SafetyMode::Safe } else { SafetyMode::Unsafe };
481 }
482 }
483
484 SafetyMode::Unsafe
486}
487
488fn is_function_safe_with_external(
489 func_name: &str,
490 safety_context: &SafetyContext,
491 known_safe_functions: &HashSet<String>,
492 external_annotations: Option<&ExternalAnnotations>,
493) -> bool {
494 get_called_function_safety(func_name, safety_context, known_safe_functions, external_annotations) == SafetyMode::Safe
495}
496
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::{Statement, Expression, SourceLocation};

    /// Builds a `FunctionCall` statement located at `test.cpp:10:5`.
    fn call_statement(name: &str, args: Vec<Expression>) -> Statement {
        Statement::FunctionCall {
            name: name.to_string(),
            args,
            location: SourceLocation {
                file: "test.cpp".to_string(),
                line: 10,
                column: 5,
            },
        }
    }

    /// Runs the single-statement checker against a fresh, empty `SafetyContext`.
    fn check_with_empty_context(stmt: &Statement, known_safe: &HashSet<String>) -> Option<String> {
        let safety_context = SafetyContext::new();
        check_statement_for_unsafe_calls(stmt, &safety_context, known_safe)
    }

    #[test]
    fn test_detect_unsafe_function_call() {
        let stmt = call_statement("unknown_func", Vec::new());

        let error_msg = check_with_empty_context(&stmt, &HashSet::new())
            .expect("an unannotated function should be reported");
        assert!(error_msg.contains("unknown_func"));
        assert!(error_msg.contains("unsafe"));
    }

    #[test]
    fn test_stl_functions_require_unsafe() {
        // std::move is neither annotated nor known-safe here, so it must be reported.
        let stmt = call_statement("std::move", vec![Expression::Variable("x".to_string())]);

        let error_msg = check_with_empty_context(&stmt, &HashSet::new())
            .expect("std::move should require @unsafe block in safe code");
        assert!(error_msg.contains("std::move"));
        assert!(error_msg.contains("@unsafe"));
    }

    #[test]
    fn test_known_safe_function() {
        let known_safe: HashSet<String> =
            std::iter::once("my_safe_func".to_string()).collect();
        let stmt = call_statement("my_safe_func", Vec::new());

        assert!(
            check_with_empty_context(&stmt, &known_safe).is_none(),
            "Known safe function should be allowed"
        );
    }

    #[test]
    fn test_unsafe_call_in_expression() {
        let stmt = Statement::Assignment {
            lhs: Expression::Variable("x".to_string()),
            rhs: Expression::FunctionCall {
                name: "unsafe_func".to_string(),
                args: Vec::new(),
            },
            location: SourceLocation {
                file: "test.cpp".to_string(),
                line: 15,
                column: 5,
            },
        };

        let error_msg = check_with_empty_context(&stmt, &HashSet::new())
            .expect("an unsafe call on the right-hand side should be reported");
        assert!(error_msg.contains("unsafe_func"));
    }
}