rusty_cpp/analysis/
unsafe_propagation.rs

1use crate::parser::{Function, Statement, Expression};
2use crate::parser::safety_annotations::{SafetyContext, SafetyMode};
3use crate::parser::external_annotations::ExternalAnnotations;
4use std::collections::HashSet;
5
6/// Check for unsafe propagation in safe functions
7///
8/// In safe code, the following require explicit @unsafe annotation:
9/// 1. Calling functions not marked as @safe
10/// 2. Using types/structs not marked as @safe
11/// 3. Any operation on unsafe types
12pub fn check_unsafe_propagation(
13    function: &Function,
14    safety_context: &SafetyContext,
15    known_safe_functions: &HashSet<String>,
16) -> Vec<String> {
17    check_unsafe_propagation_with_external(function, safety_context, known_safe_functions, None)
18}
19
20/// Check for unsafe propagation with external annotations support
21pub fn check_unsafe_propagation_with_external(
22    function: &Function,
23    safety_context: &SafetyContext,
24    known_safe_functions: &HashSet<String>,
25    external_annotations: Option<&ExternalAnnotations>,
26) -> Vec<String> {
27    let mut errors = Vec::new();
28    let mut unsafe_depth = 0;
29
30    // Collect callable parameters - parameters whose type is or contains a template type parameter
31    // e.g., for template<typename F> void foo(F&& write_fn), "write_fn" is a callable parameter
32    let callable_params = get_callable_parameters(&function.parameters, &function.template_parameters);
33
34    // Check each statement in the function
35    for stmt in &function.body {
36        // Track unsafe scope depth
37        match stmt {
38            Statement::EnterUnsafe => {
39                unsafe_depth += 1;
40                continue;
41            }
42            Statement::ExitUnsafe => {
43                if unsafe_depth > 0 {
44                    unsafe_depth -= 1;
45                }
46                continue;
47            }
48            _ => {}
49        }
50
51        // Skip checking if we're in an unsafe block
52        let in_unsafe_scope = unsafe_depth > 0;
53
54        if let Some(error) = check_statement_for_unsafe_calls_with_external(
55            stmt, safety_context, known_safe_functions, external_annotations,
56            &function.template_parameters, &callable_params, in_unsafe_scope
57        ) {
58            errors.push(format!("In function '{}': {}", function.name, error));
59        }
60    }
61
62    errors
63}
64
65/// Get list of parameter names that are callable (their type is/contains a template parameter)
66/// For example: template<typename F> void foo(F&& write_fn) -> returns ["write_fn"]
67fn get_callable_parameters(parameters: &[crate::parser::Variable], template_params: &[String]) -> HashSet<String> {
68    let mut callable_params = HashSet::new();
69
70    for param in parameters {
71        // Check if the parameter's type contains any template type parameter
72        // This handles: F, F&&, F&, const F&, std::function<...> where ... contains F, etc.
73        let type_name = &param.type_name;
74
75        for template_param in template_params {
76            // Check if the type contains the template parameter
77            // Handle cases like: F, F&&, F&, const F&, F *, etc.
78            if type_contains_template_param(type_name, template_param) {
79                callable_params.insert(param.name.clone());
80                break;
81            }
82        }
83    }
84
85    callable_params
86}
87
/// Check whether `type_name` mentions the template parameter `template_param`.
///
/// Two tests are combined:
/// 1. after deleting every `const`, `&&`, `&`, `*` and space from `type_name`,
///    the remainder equals `template_param` (e.g. `const F&` -> `F`);
/// 2. `template_param` appears as a whole identifier token of `type_name`,
///    where tokens are maximal runs of alphanumerics and `_`
///    (e.g. `T` in `std::function<void(T)>`, but not `F` in `Foo`).
fn type_contains_template_param(type_name: &str, template_param: &str) -> bool {
    // Test 1: strip qualifier / reference / pointer / whitespace noise.
    let stripped = ["const", "&&", "&", "*", " "]
        .iter()
        .fold(type_name.to_string(), |acc, noise| acc.replace(*noise, ""));
    if stripped == template_param {
        return true;
    }

    // Test 2: whole-word identifier match.
    type_name
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .filter(|token| !token.is_empty())
        .any(|token| token == template_param)
}
108
109fn check_statement_for_unsafe_calls(
110    stmt: &Statement,
111    safety_context: &SafetyContext,
112    known_safe_functions: &HashSet<String>,
113) -> Option<String> {
114    check_statement_for_unsafe_calls_with_external(stmt, safety_context, known_safe_functions, None, &[], &HashSet::new(), false)
115}
116
/// Check if a name looks like a template type parameter (including variadic pack parameters).
///
/// Always `false` when `template_params` is empty: outside a template there are
/// no type parameters, so nothing can be "template-like". Otherwise matches:
/// - exact matches: `T`, `Args`
/// - pack expansions: `Args...`, `Rest...` (base name must be declared)
/// - reference-suffixed forms: `T&`, `Args&&` (forwarding references in packs)
/// - heuristic: identifiers of at most 8 bytes that start with an uppercase
///   letter and contain only alphanumerics / `_`, to tolerate variations the
///   parser may produce for template parameter names
fn is_template_parameter_like(name: &str, template_params: &[String]) -> bool {
    // Without this guard the uppercase heuristic below would fire in ordinary
    // functions too, silently treating calls such as `Free(p)` or `Delete()` as
    // safe. This mirrors the template-context gate on "unknown" callees.
    if template_params.is_empty() {
        return false;
    }

    let is_declared = |candidate: &str| template_params.iter().any(|param| param == candidate);

    // Exact match
    if is_declared(name) {
        return true;
    }

    // Pack expansion: "Args..." where "Args" is a template parameter
    if name.ends_with("...") && is_declared(name.trim_end_matches("...").trim()) {
        return true;
    }

    // Reference suffix: "Args&&" / "T&" where the base is a template parameter
    if name.ends_with('&') && is_declared(name.trim_end_matches('&').trim()) {
        return true;
    }

    // Heuristic: short, uppercase-initial identifier
    let starts_uppercase = name.chars().next().map_or(false, char::is_uppercase);
    starts_uppercase
        && name.len() <= 8
        && name.chars().all(|c| c.is_alphanumeric() || c == '_')
}
160
161/// Process a list of statements while tracking unsafe depth, returning all errors found
162fn check_statements_with_unsafe_tracking(
163    statements: &[Statement],
164    safety_context: &SafetyContext,
165    known_safe_functions: &HashSet<String>,
166    external_annotations: Option<&ExternalAnnotations>,
167    template_params: &[String],
168    callable_params: &HashSet<String>,
169    initial_unsafe_depth: usize,
170) -> Vec<String> {
171    let mut errors = Vec::new();
172    let mut unsafe_depth = initial_unsafe_depth;
173
174    for stmt in statements {
175        // Track unsafe scope depth
176        match stmt {
177            Statement::EnterUnsafe => {
178                unsafe_depth += 1;
179                continue;
180            }
181            Statement::ExitUnsafe => {
182                if unsafe_depth > 0 {
183                    unsafe_depth -= 1;
184                }
185                continue;
186            }
187            _ => {}
188        }
189
190        let in_unsafe_scope = unsafe_depth > 0;
191
192        if let Some(error) = check_statement_for_unsafe_calls_with_external(
193            stmt, safety_context, known_safe_functions, external_annotations,
194            template_params, callable_params, in_unsafe_scope
195        ) {
196            errors.push(error);
197        }
198    }
199
200    errors
201}
202
203fn check_statement_for_unsafe_calls_with_external(
204    stmt: &Statement,
205    safety_context: &SafetyContext,
206    known_safe_functions: &HashSet<String>,
207    external_annotations: Option<&ExternalAnnotations>,
208    template_params: &[String],
209    callable_params: &HashSet<String>,
210    in_unsafe_scope: bool,
211) -> Option<String> {
212    use crate::parser::Statement;
213
214    // Skip all checks if we're in an unsafe block
215    if in_unsafe_scope {
216        return None;
217    }
218
219    match stmt {
220        Statement::FunctionCall { name, location, .. } => {
221            // Check if this is a template type parameter (not a real function call)
222            // Phase 1: Enhanced check for variadic pack parameters
223            if is_template_parameter_like(name, template_params) {
224                return None; // Template type parameters are safe to use
225            }
226
227            // Special case: "unknown" function calls in template context are likely template type constructors
228            // e.g., T(), T(x), etc. where the parser couldn't determine the name
229            if !template_params.is_empty() && name == "unknown" {
230                return None; // Allow unknown function calls in template context
231            }
232
233            // Special case: Lambda operator() calls
234            // Lambdas defined in @safe context have already been checked for safety
235            // Their operator() is safe to call
236            if name == "operator()" || name.contains("operator()") {
237                return None; // Lambda calls are safe - their body was already checked
238            }
239
240            // Special case: Callable template parameters
241            // e.g., template<typename F> void foo(F&& write_fn) { write_fn(42); }
242            // Calling write_fn is safe because it's a callable passed by the caller
243            // Note: In class methods, the name might be prefixed with class name (e.g., "Class::handler")
244            if callable_params.contains(name) {
245                return None; // Callable parameters are safe to invoke
246            }
247            // Also check for class-prefixed version (e.g., "Class::handler" -> check "handler")
248            if let Some(simple_name) = name.rsplit("::").next() {
249                if callable_params.contains(simple_name) {
250                    return None; // Callable parameters are safe to invoke
251                }
252            }
253
254            // Get the safety mode of the called function
255            let called_safety = get_called_function_safety(name, safety_context, known_safe_functions, external_annotations);
256
257            match called_safety {
258                SafetyMode::Safe => {
259                    // OK: safe can call safe
260                }
261                SafetyMode::Unsafe => {
262                    // ERROR: safe cannot call unsafe/unannotated functions directly
263                    // Must wrap in @unsafe { } block
264                    return Some(format!(
265                        "Calling non-safe function '{}' at line {} requires @unsafe {{ }} block",
266                        name, location.line
267                    ));
268                }
269            }
270        }
271        Statement::Assignment { rhs, location, .. } => {
272            // Check for function calls in the right-hand side
273            if let Some(unsafe_func) = find_unsafe_function_call_with_external(rhs, safety_context, known_safe_functions, external_annotations, template_params, callable_params) {
274                return Some(format!(
275                    "Calling unsafe function '{}' at line {} requires unsafe context",
276                    unsafe_func, location.line
277                ));
278            }
279        }
280        Statement::Return(Some(expr)) => {
281            // Check for function calls in return expression
282            if let Some(unsafe_func) = find_unsafe_function_call_with_external(expr, safety_context, known_safe_functions, external_annotations, template_params, callable_params) {
283                return Some(format!(
284                    "Calling unsafe function '{}' in return statement requires unsafe context",
285                    unsafe_func
286                ));
287            }
288        }
289        Statement::If { condition, then_branch, else_branch, location } => {
290            // Check condition
291            if let Some(unsafe_func) = find_unsafe_function_call_with_external(condition, safety_context, known_safe_functions, external_annotations, template_params, callable_params) {
292                return Some(format!(
293                    "Calling unsafe function '{}' in condition at line {} requires unsafe context",
294                    unsafe_func, location.line
295                ));
296            }
297
298            // Recursively check branches with proper unsafe depth tracking
299            // Start with unsafe_depth=0 since in_unsafe_scope=false here (we return early if true)
300            let then_errors = check_statements_with_unsafe_tracking(
301                then_branch, safety_context, known_safe_functions, external_annotations,
302                template_params, callable_params, 0
303            );
304            if !then_errors.is_empty() {
305                return Some(then_errors.into_iter().next().unwrap());
306            }
307
308            if let Some(else_stmts) = else_branch {
309                let else_errors = check_statements_with_unsafe_tracking(
310                    else_stmts, safety_context, known_safe_functions, external_annotations,
311                    template_params, callable_params, 0
312                );
313                if !else_errors.is_empty() {
314                    return Some(else_errors.into_iter().next().unwrap());
315                }
316            }
317        }
318        Statement::Block(statements) => {
319            // Check all statements in the block with proper unsafe depth tracking
320            let block_errors = check_statements_with_unsafe_tracking(
321                statements, safety_context, known_safe_functions, external_annotations,
322                template_params, callable_params, 0
323            );
324            if !block_errors.is_empty() {
325                return Some(block_errors.into_iter().next().unwrap());
326            }
327        }
328        _ => {}
329    }
330
331    None
332}
333
334fn find_unsafe_function_call(
335    expr: &Expression,
336    safety_context: &SafetyContext,
337    known_safe_functions: &HashSet<String>,
338) -> Option<String> {
339    find_unsafe_function_call_with_external(expr, safety_context, known_safe_functions, None, &[], &HashSet::new())
340}
341
342fn find_unsafe_function_call_with_external(
343    expr: &Expression,
344    safety_context: &SafetyContext,
345    known_safe_functions: &HashSet<String>,
346    external_annotations: Option<&ExternalAnnotations>,
347    template_params: &[String],
348    callable_params: &HashSet<String>,
349) -> Option<String> {
350    use crate::parser::Expression;
351
352    match expr {
353        Expression::FunctionCall { name, args } => {
354            // Check if this is a template type parameter (not a real function call)
355            // Phase 1: Enhanced check for variadic pack parameters
356            if is_template_parameter_like(name, template_params) {
357                // Template type parameters are safe to use (e.g., T x = ...)
358                // Just check the arguments
359                for arg in args {
360                    if let Some(unsafe_func) = find_unsafe_function_call_with_external(arg, safety_context, known_safe_functions, external_annotations, template_params, callable_params) {
361                        return Some(unsafe_func);
362                    }
363                }
364                return None;
365            }
366
367            // Special case: "unknown" function calls in template context are likely template type constructors
368            // e.g., T(), T(x), etc. where the parser couldn't determine the name
369            if !template_params.is_empty() && name == "unknown" {
370                // Just check the arguments
371                for arg in args {
372                    if let Some(unsafe_func) = find_unsafe_function_call_with_external(arg, safety_context, known_safe_functions, external_annotations, template_params, callable_params) {
373                        return Some(unsafe_func);
374                    }
375                }
376                return None; // Allow unknown function calls in template context
377            }
378
379            // Special case: Lambda operator() calls
380            // Lambdas defined in @safe context have already been checked for safety
381            if name == "operator()" || name.contains("operator()") {
382                // Just check the arguments
383                for arg in args {
384                    if let Some(unsafe_func) = find_unsafe_function_call_with_external(arg, safety_context, known_safe_functions, external_annotations, template_params, callable_params) {
385                        return Some(unsafe_func);
386                    }
387                }
388                return None; // Lambda calls are safe - their body was already checked
389            }
390
391            // Special case: Callable template parameters
392            // e.g., template<typename F> void foo(F&& write_fn) { write_fn(42); }
393            // Note: In class methods, the name might be prefixed with class name (e.g., "Class::handler")
394            let is_callable_param = callable_params.contains(name) ||
395                name.rsplit("::").next().map(|s| callable_params.contains(s)).unwrap_or(false);
396            if is_callable_param {
397                // Just check the arguments
398                for arg in args {
399                    if let Some(unsafe_func) = find_unsafe_function_call_with_external(arg, safety_context, known_safe_functions, external_annotations, template_params, callable_params) {
400                        return Some(unsafe_func);
401                    }
402                }
403                return None; // Callable parameters are safe to invoke
404            }
405
406            // Get the safety mode of the called function
407            let called_safety = get_called_function_safety(name, safety_context, known_safe_functions, external_annotations);
408
409            // Apply the corrected rules:
410            // - Safe functions can call safe functions
411            // - Safe functions can call unsafe functions (they're explicitly marked)
412            // - Safe functions CANNOT call undeclared functions
413            match called_safety {
414                SafetyMode::Safe => {
415                    // OK: safe can call safe
416                }
417                SafetyMode::Unsafe => {
418                    // Error: safe function cannot call unsafe function directly
419                    return Some(format!("{} (non-safe - use @unsafe block)", name));
420                }
421            }
422
423            // Check arguments for nested unsafe calls
424            for arg in args {
425                if let Some(unsafe_func) = find_unsafe_function_call_with_external(arg, safety_context, known_safe_functions, external_annotations, template_params, callable_params) {
426                    return Some(unsafe_func);
427                }
428            }
429        }
430        Expression::BinaryOp { left, right, .. } => {
431            // Check both sides
432            if let Some(unsafe_func) = find_unsafe_function_call_with_external(left, safety_context, known_safe_functions, external_annotations, template_params, callable_params) {
433                return Some(unsafe_func);
434            }
435            if let Some(unsafe_func) = find_unsafe_function_call_with_external(right, safety_context, known_safe_functions, external_annotations, template_params, callable_params) {
436                return Some(unsafe_func);
437            }
438        }
439        Expression::Move { inner, .. } | Expression::Dereference(inner) | Expression::AddressOf(inner) => {
440            // Check inner expression
441            if let Some(unsafe_func) = find_unsafe_function_call_with_external(inner, safety_context, known_safe_functions, external_annotations, template_params, callable_params) {
442                return Some(unsafe_func);
443            }
444        }
445        _ => {}
446    }
447
448    None
449}
450
451fn is_function_safe(
452    func_name: &str,
453    safety_context: &SafetyContext,
454    known_safe_functions: &HashSet<String>,
455) -> bool {
456    is_function_safe_with_external(func_name, safety_context, known_safe_functions, None)
457}
458
459/// Get the safety mode of a called function
460fn get_called_function_safety(
461    func_name: &str,
462    safety_context: &SafetyContext,
463    known_safe_functions: &HashSet<String>,
464    external_annotations: Option<&ExternalAnnotations>,
465) -> SafetyMode {
466    // First check if we know about this function in our context
467    let local_safety = safety_context.get_function_safety(func_name);
468    if local_safety != SafetyMode::Unsafe {
469        return local_safety;
470    }
471
472    // Check if it's in our known safe functions set
473    if known_safe_functions.contains(func_name) {
474        return SafetyMode::Safe;
475    }
476
477    // Check external annotations if provided
478    if let Some(annotations) = external_annotations {
479        if let Some(is_safe) = annotations.is_function_safe(func_name) {
480            return if is_safe { SafetyMode::Safe } else { SafetyMode::Unsafe };
481        }
482    }
483
484    // Default to unsafe - all unannotated functions are unsafe
485    SafetyMode::Unsafe
486}
487
488fn is_function_safe_with_external(
489    func_name: &str,
490    safety_context: &SafetyContext,
491    known_safe_functions: &HashSet<String>,
492    external_annotations: Option<&ExternalAnnotations>,
493) -> bool {
494    get_called_function_safety(func_name, safety_context, known_safe_functions, external_annotations) == SafetyMode::Safe
495}
496
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::{Expression, SourceLocation, Statement};

    /// Builds a statement-level call to `name` located at test.cpp, line 10, column 5.
    fn call_stmt(name: &str, args: Vec<Expression>) -> Statement {
        Statement::FunctionCall {
            name: name.to_string(),
            args,
            location: SourceLocation {
                file: "test.cpp".to_string(),
                line: 10,
                column: 5,
            },
        }
    }

    /// Runs the single-statement checker against a fresh, empty `SafetyContext`.
    fn check_with(stmt: &Statement, known_safe: &HashSet<String>) -> Option<String> {
        let safety_context = SafetyContext::new();
        check_statement_for_unsafe_calls(stmt, &safety_context, known_safe)
    }

    #[test]
    fn test_detect_unsafe_function_call() {
        let error = check_with(&call_stmt("unknown_func", vec![]), &HashSet::new());
        assert!(error.is_some());
        let error_msg = error.unwrap();
        assert!(error_msg.contains("unknown_func"));
        assert!(error_msg.contains("unsafe"));
    }

    #[test]
    fn test_stl_functions_require_unsafe() {
        // With the new two-state model, ALL non-safe functions (including STL) require @unsafe blocks
        let stmt = call_stmt("std::move", vec![Expression::Variable("x".to_string())]);
        let error = check_with(&stmt, &HashSet::new());
        assert!(error.is_some(), "std::move should require @unsafe block in safe code");
        let error_msg = error.unwrap();
        assert!(error_msg.contains("std::move"));
        assert!(error_msg.contains("@unsafe"));
    }

    #[test]
    fn test_known_safe_function() {
        let known_safe: HashSet<String> = std::iter::once("my_safe_func".to_string()).collect();
        let error = check_with(&call_stmt("my_safe_func", vec![]), &known_safe);
        assert!(error.is_none(), "Known safe function should be allowed");
    }

    #[test]
    fn test_unsafe_call_in_expression() {
        let stmt = Statement::Assignment {
            lhs: Expression::Variable("x".to_string()),
            rhs: Expression::FunctionCall {
                name: "unsafe_func".to_string(),
                args: vec![],
            },
            location: SourceLocation {
                file: "test.cpp".to_string(),
                line: 15,
                column: 5,
            },
        };

        let error = check_with(&stmt, &HashSet::new());
        assert!(error.is_some());
        assert!(error.unwrap().contains("unsafe_func"));
    }
}