rusty_cpp/analysis/
function_pointer_safety.rs

//! Function Pointer Safety Analysis
//!
//! This module implements safety checking for function pointers using type-level
//! encoding with `SafeFn<Sig>` and `UnsafeFn<Sig>` wrapper types.
//!
//! Key concepts:
//! - `SafeFn<Ret(Args...)>` - holds a pointer to a @safe function, can be called safely
//! - `UnsafeFn<Ret(Args...)>` - holds any function pointer, requires @unsafe to call
//! - Raw function pointers require @unsafe to call
//!
//! See `docs/FUNCTION_POINTER_SAFETY_PLAN.md` for the full design.
12
13use crate::parser::{Expression, Function, Statement};
14use crate::parser::safety_annotations::SafetyMode;
15use std::collections::HashMap;
16
/// Outcome of validating one assignment into a `SafeFn` variable.
///
/// NOTE(review): nothing in the visible part of this file constructs this
/// type; the field descriptions below follow the field names — confirm
/// against the code that produces it.
#[derive(Debug, Clone)]
pub struct SafeFnAssignmentCheck {
    /// Name of the variable on the receiving side of the assignment.
    pub variable_name: String,
    /// Name of the function whose address is being stored.
    pub target_function: String,
    /// `true` when the assignment was accepted.
    pub is_valid: bool,
    /// Diagnostic text attached to the outcome, if any.
    pub error_message: Option<String>,
}
25
// ============================================================================
// Type Detection
// ============================================================================
29
/// Report whether `type_name` spells an instantiation of one of `templates`.
///
/// Before matching, the spelling is reduced as follows:
/// 1. leading whitespace and leading `const ` / `volatile ` qualifiers are
///    removed (this must happen while the separating space still exists,
///    otherwise `const rusty::SafeFn<...>` collapses to `construsty::...`);
/// 2. all remaining spaces are removed;
/// 3. a leading global-scope `::` and an optional `rusty::` namespace are
///    removed.
///
/// What is left must start with a template name immediately followed by `<`.
fn is_fn_wrapper_instantiation(type_name: &str, templates: &[&str]) -> bool {
    let mut unqualified = type_name.trim_start();
    loop {
        if let Some(rest) = unqualified.strip_prefix("const ") {
            unqualified = rest.trim_start();
        } else if let Some(rest) = unqualified.strip_prefix("volatile ") {
            unqualified = rest.trim_start();
        } else {
            break;
        }
    }

    let compact = unqualified.replace(' ', "");
    let name = compact.strip_prefix("::").unwrap_or(&compact);
    let name = name.strip_prefix("rusty::").unwrap_or(name);

    templates.iter().any(|&template| {
        name.strip_prefix(template)
            .map_or(false, |rest| rest.starts_with('<'))
    })
}

/// Check if a type is a SafeFn wrapper type
///
/// Returns `true` for `SafeFn<...>` and `SafeMemFn<...>`, with or without the
/// `rusty::` namespace. Spaces inside the spelling are ignored, and leading
/// `const` / `volatile` qualifiers or a leading `::` do not prevent a match.
pub fn is_safe_fn_type(type_name: &str) -> bool {
    is_fn_wrapper_instantiation(type_name, &["SafeFn", "SafeMemFn"])
}

/// Check if a type is an UnsafeFn wrapper type
///
/// Returns `true` for `UnsafeFn<...>` and `UnsafeMemFn<...>`, with or without
/// the `rusty::` namespace. Spaces inside the spelling are ignored, and
/// leading `const` / `volatile` qualifiers or a leading `::` do not prevent a
/// match.
pub fn is_unsafe_fn_type(type_name: &str) -> bool {
    is_fn_wrapper_instantiation(type_name, &["UnsafeFn", "UnsafeMemFn"])
}
47
/// Check if a type is a raw function pointer
///
/// Matches patterns like: `void (*)(int)`, `int (*)(const char*, ...)`,
/// `void (MyClass::*)(int)`.
///
/// The pointer declarator (`(*)` or `::*)`) only counts when it appears
/// outside every template argument list. A type that merely mentions a
/// function pointer as a template argument, such as
/// `rusty::SafeMemFn<void (MyClass::*)(int)>` or `std::vector<void (*)(int)>`,
/// is not itself a raw function pointer and yields `false`.
/// Whitespace in the spelling is ignored.
pub fn is_raw_function_pointer_type(type_name: &str) -> bool {
    let compact: String = type_name.chars().filter(|c| !c.is_whitespace()).collect();

    let mut template_depth = 0usize;
    let mut previous = None;
    for (offset, current) in compact.char_indices() {
        match current {
            '<' => template_depth += 1,
            // The `>` of a trailing-return `->` does not close a template list.
            '>' if previous != Some('-') => template_depth = template_depth.saturating_sub(1),
            _ if template_depth == 0 => {
                let tail = &compact[offset..];
                // `(*)` is a free function pointer; `::*)` closes `(Class::*)`.
                if tail.starts_with("(*)") || tail.starts_with("::*)") {
                    return true;
                }
            }
            _ => {}
        }
        previous = Some(current);
    }

    false
}
56
57/// Check if a function call is calling through a SafeFn or SafeMemFn wrapper
58pub fn is_safe_fn_call(callee_type: &str, method_name: &str) -> bool {
59    // SafeFn<Sig>::operator() or SafeMemFn<Sig>::operator() - safe to call
60    is_safe_fn_type(callee_type) && method_name == "operator()"
61}
62
63/// Check if a function call is calling through an UnsafeFn or UnsafeMemFn wrapper
64pub fn is_unsafe_fn_call_unsafe_method(callee_type: &str, method_name: &str) -> bool {
65    // UnsafeFn<Sig>::call_unsafe or UnsafeMemFn<Sig>::call_unsafe - requires @unsafe
66    is_unsafe_fn_type(callee_type) && method_name == "call_unsafe"
67}
68
/// Check if a type is a member function pointer wrapper (safe or unsafe)
///
/// The test is purely textual: after dropping every space character, the
/// spelling must contain `MemFn<` somewhere.
pub fn is_member_fn_wrapper_type(type_name: &str) -> bool {
    const MARKER: &str = "MemFn<";
    let compact: String = type_name.chars().filter(|&c| c != ' ').collect();
    compact.contains(MARKER)
}
74
75/// Check if a type is a raw member function pointer
76/// Matches patterns like: void (MyClass::*)(int), int (Widget::*)(double) const
77pub fn is_raw_member_function_pointer_type(type_name: &str) -> bool {
78    // Exclude wrapper types first
79    if is_safe_fn_type(type_name) || is_unsafe_fn_type(type_name) {
80        return false;
81    }
82    // Member function pointer pattern: Ret (Class::*)(Args...)
83    type_name.contains("::*)") || type_name.contains("::*)(")
84}
85
// ============================================================================
// Safety Checking
// ============================================================================
89
90/// Check function pointer safety in a parsed function
91///
92/// This checks:
93/// 1. SafeFn assignments have @safe targets
94/// 2. Raw function pointer calls require @unsafe
95/// 3. UnsafeFn::call_unsafe() requires @unsafe
96pub fn check_function_pointer_safety(
97    function: &Function,
98    function_safety: SafetyMode,
99    known_safe_functions: &HashMap<String, SafetyMode>,
100) -> Vec<String> {
101    let mut errors = Vec::new();
102    let mut unsafe_depth = 0;
103
104    // Only check in @safe functions
105    if function_safety != SafetyMode::Safe {
106        return errors;
107    }
108
109    for stmt in &function.body {
110        // Track unsafe scope
111        match stmt {
112            Statement::EnterUnsafe => {
113                unsafe_depth += 1;
114                continue;
115            }
116            Statement::ExitUnsafe => {
117                if unsafe_depth > 0 {
118                    unsafe_depth -= 1;
119                }
120                continue;
121            }
122            _ => {}
123        }
124
125        let in_unsafe_scope = unsafe_depth > 0;
126
127        if let Some(error) = check_statement_for_function_pointer_safety(
128            stmt,
129            in_unsafe_scope,
130            known_safe_functions,
131        ) {
132            errors.push(format!("In function '{}': {}", function.name, error));
133        }
134    }
135
136    errors
137}
138
139fn check_statement_for_function_pointer_safety(
140    stmt: &Statement,
141    in_unsafe_scope: bool,
142    known_safe_functions: &HashMap<String, SafetyMode>,
143) -> Option<String> {
144    // Skip checks in unsafe scope
145    if in_unsafe_scope {
146        return None;
147    }
148
149    match stmt {
150        Statement::VariableDecl(var) => {
151            // Check if declaring a SafeFn and initializing from a function address
152            if is_safe_fn_type(&var.type_name) {
153                // We need to check the initializer, but Variable doesn't have initializer info
154                // This check happens at the call site in check_safe_fn_assignment
155                return None;
156            }
157
158            // Raw function pointer declarations are OK, but calling is checked elsewhere
159            None
160        }
161
162        Statement::Assignment { lhs, rhs, location } => {
163            // Check if assigning to a SafeFn variable
164            if let Expression::Variable(var_name) = lhs {
165                // Check if rhs is a function address being assigned to SafeFn
166                if let Some(error) = check_safe_fn_assignment_expr(
167                    var_name,
168                    rhs,
169                    known_safe_functions,
170                    location.line as usize,
171                ) {
172                    return Some(error);
173                }
174            }
175            None
176        }
177
178        Statement::FunctionCall { name, args, location, .. } => {
179            // Check for raw function pointer calls
180            // A call through a raw function pointer looks like: fp(args)
181            // where fp is a variable of function pointer type
182
183            // Check for UnsafeFn::call_unsafe() calls
184            if name.ends_with("::call_unsafe") || name.ends_with(".call_unsafe") {
185                return Some(format!(
186                    "Call to UnsafeFn::call_unsafe() at line {} requires @unsafe context",
187                    location.line as usize
188                ));
189            }
190
191            // Check for calls through raw function pointers
192            // This is detected by checking if 'name' is a variable reference, not a function name
193            // We can't easily detect this without type info, so we'll check for common patterns
194
195            None
196        }
197
198        Statement::If { then_branch, else_branch, .. } => {
199            // Recursively check branches
200            for stmt in then_branch {
201                if let Some(error) = check_statement_for_function_pointer_safety(
202                    stmt, in_unsafe_scope, known_safe_functions
203                ) {
204                    return Some(error);
205                }
206            }
207            if let Some(else_stmts) = else_branch {
208                for stmt in else_stmts {
209                    if let Some(error) = check_statement_for_function_pointer_safety(
210                        stmt, in_unsafe_scope, known_safe_functions
211                    ) {
212                        return Some(error);
213                    }
214                }
215            }
216            None
217        }
218
219        Statement::Block(stmts) => {
220            for stmt in stmts {
221                if let Some(error) = check_statement_for_function_pointer_safety(
222                    stmt, in_unsafe_scope, known_safe_functions
223                ) {
224                    return Some(error);
225                }
226            }
227            None
228        }
229
230        _ => None,
231    }
232}
233
234/// Check if an expression assigned to a SafeFn variable is valid
235fn check_safe_fn_assignment_expr(
236    _var_name: &str,
237    rhs: &Expression,
238    known_safe_functions: &HashMap<String, SafetyMode>,
239    line: usize,
240) -> Option<String> {
241    // Extract function name from address-of expression
242    let func_name = match extract_function_from_address_of(rhs) {
243        Some(name) => name,
244        None => return None, // Not an address-of expression, skip
245    };
246
247    // Check if the function is known to be @safe
248    match known_safe_functions.get(&func_name) {
249        Some(SafetyMode::Safe) => None, // OK
250        Some(SafetyMode::Unsafe) => {
251            Some(format!(
252                "Cannot assign @unsafe function '{}' to SafeFn at line {}. \
253                 SafeFn can only hold pointers to @safe functions.",
254                func_name, line
255            ))
256        }
257        None => {
258            // Unknown functions are treated as @unsafe by default (two-state model)
259            Some(format!(
260                "Cannot assign unannotated function '{}' to SafeFn at line {}. \
261                 The target function must be marked @safe. \
262                 (Unannotated functions are @unsafe by default)",
263                func_name, line
264            ))
265        }
266    }
267}
268
269/// Extract function name from an address-of expression
270fn extract_function_from_address_of(expr: &Expression) -> Option<String> {
271    match expr {
272        Expression::AddressOf(inner) => {
273            match inner.as_ref() {
274                Expression::Variable(name) => Some(name.clone()),
275                Expression::MemberAccess { object, field } => {
276                    // &ClassName::method
277                    if let Expression::Variable(class_name) = object.as_ref() {
278                        Some(format!("{}::{}", class_name, field))
279                    } else {
280                        None
281                    }
282                }
283                _ => None,
284            }
285        }
286        // Direct function name (without &) - some compilers allow this
287        Expression::Variable(name) => {
288            // Check if it looks like a function name (not a variable)
289            // This is a heuristic
290            if name.contains("::") || name.starts_with(|c: char| c.is_uppercase()) {
291                Some(name.clone())
292            } else {
293                None
294            }
295        }
296        _ => None,
297    }
298}
299
300/// Check if a function call expression is through a raw function pointer
301/// Returns Some(error) if the call requires @unsafe
302pub fn check_raw_function_pointer_call(
303    callee: &Expression,
304    callee_type: Option<&str>,
305    line: usize,
306) -> Option<String> {
307    // If we have type info and it's a raw function pointer, flag it
308    if let Some(type_name) = callee_type {
309        if is_raw_function_pointer_type(type_name) {
310            return Some(format!(
311                "Call through raw function pointer at line {} requires @unsafe context. \
312                 Consider using SafeFn<Sig> or UnsafeFn<Sig> wrapper types.",
313                line
314            ));
315        }
316    }
317
318    // Check expression patterns that indicate a function pointer call
319    match callee {
320        Expression::Dereference(inner) => {
321            // (*fp)(args) - explicit dereference of function pointer
322            if let Expression::Variable(_) = inner.as_ref() {
323                return Some(format!(
324                    "Call through dereferenced function pointer at line {} requires @unsafe context. \
325                     Consider using SafeFn<Sig> or UnsafeFn<Sig> wrapper types.",
326                    line
327                ));
328            }
329        }
330        _ => {}
331    }
332
333    None
334}
335
// ============================================================================
// Tests
// ============================================================================
339
#[cfg(test)]
mod tests {
    use super::*;

    /// Build the expression `&name`.
    fn address_of_variable(name: &str) -> Expression {
        Expression::AddressOf(Box::new(Expression::Variable(name.to_string())))
    }

    /// Build a safety table that knows exactly one function.
    fn single_entry_table(name: &str, mode: SafetyMode) -> HashMap<String, SafetyMode> {
        let mut table = HashMap::new();
        table.insert(name.to_string(), mode);
        table
    }

    #[test]
    fn test_is_safe_fn_type() {
        let accepted = [
            "rusty::SafeFn<void(int)>",
            "SafeFn<int(const char*)>",
            "rusty::SafeMemFn<void (MyClass::*)(int)>",
            "SafeMemFn<int (Widget::*)(double) const>",
        ];
        for ty in accepted.iter() {
            assert!(is_safe_fn_type(ty), "should be a SafeFn wrapper: {}", ty);
        }

        let rejected = [
            "rusty::UnsafeFn<void(int)>",
            "std::function<void(int)>",
            "void (*)(int)",
        ];
        for ty in rejected.iter() {
            assert!(!is_safe_fn_type(ty), "should not be a SafeFn wrapper: {}", ty);
        }
    }

    #[test]
    fn test_is_unsafe_fn_type() {
        let accepted = [
            "rusty::UnsafeFn<void(int)>",
            "UnsafeFn<int(const char*)>",
            "rusty::UnsafeMemFn<void (MyClass::*)(int)>",
        ];
        for ty in accepted.iter() {
            assert!(is_unsafe_fn_type(ty), "should be an UnsafeFn wrapper: {}", ty);
        }

        let rejected = ["rusty::SafeFn<void(int)>", "std::function<void(int)>"];
        for ty in rejected.iter() {
            assert!(!is_unsafe_fn_type(ty), "should not be an UnsafeFn wrapper: {}", ty);
        }
    }

    #[test]
    fn test_is_raw_function_pointer_type() {
        let accepted = [
            "void (*)(int)",
            "int (*)(const char*, ...)",
            "void (MyClass::*)(int)",
            "int (Widget::*)(double) const",
        ];
        for ty in accepted.iter() {
            assert!(is_raw_function_pointer_type(ty), "should be a raw pointer: {}", ty);
        }

        let rejected = ["rusty::SafeFn<void(int)>", "std::function<void(int)>", "void"];
        for ty in rejected.iter() {
            assert!(!is_raw_function_pointer_type(ty), "should not be a raw pointer: {}", ty);
        }
    }

    #[test]
    fn test_is_safe_fn_call() {
        let safe_calls = [
            ("rusty::SafeFn<void(int)>", "operator()"),
            ("SafeFn<int()>", "operator()"),
        ];
        for (ty, method) in safe_calls.iter() {
            assert!(is_safe_fn_call(ty, method), "{}::{}", ty, method);
        }

        let other_calls = [
            ("rusty::SafeFn<void(int)>", "get"),
            ("rusty::UnsafeFn<void(int)>", "operator()"),
        ];
        for (ty, method) in other_calls.iter() {
            assert!(!is_safe_fn_call(ty, method), "{}::{}", ty, method);
        }
    }

    #[test]
    fn test_is_unsafe_fn_call_unsafe_method() {
        let unsafe_calls = [
            ("rusty::UnsafeFn<void(int)>", "call_unsafe"),
            ("UnsafeFn<int()>", "call_unsafe"),
        ];
        for (ty, method) in unsafe_calls.iter() {
            assert!(is_unsafe_fn_call_unsafe_method(ty, method), "{}::{}", ty, method);
        }

        let other_calls = [
            ("rusty::UnsafeFn<void(int)>", "get"),
            ("rusty::SafeFn<void(int)>", "call_unsafe"),
        ];
        for (ty, method) in other_calls.iter() {
            assert!(!is_unsafe_fn_call_unsafe_method(ty, method), "{}::{}", ty, method);
        }
    }

    #[test]
    fn test_extract_function_from_address_of() {
        // &function_name
        let free_function = address_of_variable("my_func");
        assert_eq!(
            extract_function_from_address_of(&free_function),
            Some("my_func".to_string())
        );

        // &ClassName::method
        let method = Expression::AddressOf(Box::new(Expression::MemberAccess {
            object: Box::new(Expression::Variable("MyClass".to_string())),
            field: "method".to_string(),
        }));
        assert_eq!(
            extract_function_from_address_of(&method),
            Some("MyClass::method".to_string())
        );

        // Non-address-of expression
        let plain_variable = Expression::Variable("x".to_string());
        assert_eq!(extract_function_from_address_of(&plain_variable), None);
    }

    #[test]
    fn test_check_safe_fn_assignment_with_safe_function() {
        let known_safe = single_entry_table("safe_func", SafetyMode::Safe);
        let rhs = address_of_variable("safe_func");

        let result = check_safe_fn_assignment_expr("callback", &rhs, &known_safe, 10);
        assert!(result.is_none(), "Assignment of @safe function to SafeFn should succeed");
    }

    #[test]
    fn test_check_safe_fn_assignment_with_unsafe_function() {
        let known_safe = single_entry_table("unsafe_func", SafetyMode::Unsafe);
        let rhs = address_of_variable("unsafe_func");

        let result = check_safe_fn_assignment_expr("callback", &rhs, &known_safe, 10);
        assert!(result.is_some(), "Assignment of @unsafe function to SafeFn should fail");
        assert!(result.unwrap().contains("@unsafe function"));
    }

    #[test]
    fn test_check_safe_fn_assignment_with_unknown_function() {
        let known_safe: HashMap<String, SafetyMode> = HashMap::new();
        let rhs = address_of_variable("unknown_func");

        let result = check_safe_fn_assignment_expr("callback", &rhs, &known_safe, 10);
        assert!(result.is_some(), "Assignment of unknown function to SafeFn should fail");
        assert!(result.unwrap().contains("unannotated function"));
    }

    // Member function pointer tests
    #[test]
    fn test_is_member_fn_wrapper_type() {
        let accepted = [
            "rusty::SafeMemFn<void (MyClass::*)(int)>",
            "SafeMemFn<int (Widget::*)(double) const>",
            "rusty::UnsafeMemFn<void (MyClass::*)(int)>",
            "UnsafeMemFn<int (Widget::*)()>",
        ];
        for ty in accepted.iter() {
            assert!(is_member_fn_wrapper_type(ty), "should be a MemFn wrapper: {}", ty);
        }

        let rejected = ["rusty::SafeFn<void(int)>", "void (MyClass::*)(int)"];
        for ty in rejected.iter() {
            assert!(!is_member_fn_wrapper_type(ty), "should not be a MemFn wrapper: {}", ty);
        }
    }

    #[test]
    fn test_is_raw_member_function_pointer_type() {
        let accepted = [
            "void (MyClass::*)(int)",
            "int (Widget::*)(double) const",
            "bool (Foo::*)()",
        ];
        for ty in accepted.iter() {
            assert!(
                is_raw_member_function_pointer_type(ty),
                "should be a raw member pointer: {}",
                ty
            );
        }

        let rejected = ["void (*)(int)", "rusty::SafeMemFn<void (MyClass::*)(int)>"];
        for ty in rejected.iter() {
            assert!(
                !is_raw_member_function_pointer_type(ty),
                "should not be a raw member pointer: {}",
                ty
            );
        }
    }

    #[test]
    fn test_safe_mem_fn_call() {
        // SafeMemFn::operator() should be detected as safe
        let wrappers = [
            "rusty::SafeMemFn<void (MyClass::*)(int)>",
            "SafeMemFn<int (Widget::*)() const>",
        ];
        for ty in wrappers.iter() {
            assert!(is_safe_fn_call(ty, "operator()"), "{}", ty);
        }
    }

    #[test]
    fn test_unsafe_mem_fn_call_unsafe() {
        // UnsafeMemFn::call_unsafe should be detected
        let wrappers = [
            "rusty::UnsafeMemFn<void (MyClass::*)(int)>",
            "UnsafeMemFn<int (Widget::*)()>",
        ];
        for ty in wrappers.iter() {
            assert!(is_unsafe_fn_call_unsafe_method(ty, "call_unsafe"), "{}", ty);
        }
    }
}