use crate::parser::{Expression, Function, Statement};
use crate::parser::safety_annotations::SafetyMode;
use std::collections::{HashMap, HashSet};
16
/// Result record for checking one assignment of a function to a `SafeFn`.
///
/// Derives `PartialEq`/`Eq` so results can be compared directly (e.g. in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SafeFnAssignmentCheck {
    /// Presumably the name of the `SafeFn` variable being assigned.
    pub variable_name: String,
    /// Presumably the name of the function whose address is assigned.
    pub target_function: String,
    /// Whether the assignment passed the check.
    pub is_valid: bool,
    /// Diagnostic text; presumably `Some` only when `is_valid` is false.
    pub error_message: Option<String>,
}
25
/// Returns `true` if `type_name` spells a `SafeFn<...>` or `SafeMemFn<...>` wrapper type.
///
/// All whitespace is ignored. The wrapper may be written unqualified (`SafeFn<...>`),
/// namespace-qualified (`rusty::SafeFn<...>`), or globally qualified
/// (`::rusty::SafeFn<...>` / `::SafeFn<...>`).
pub fn is_safe_fn_type(type_name: &str) -> bool {
    let normalized: String = type_name.chars().filter(|c| !c.is_whitespace()).collect();
    // Peel an optional global-scope `::`, then an optional `rusty::` namespace.
    let unqualified = normalized.strip_prefix("::").unwrap_or(normalized.as_str());
    let unqualified = unqualified.strip_prefix("rusty::").unwrap_or(unqualified);
    unqualified.starts_with("SafeFn<") || unqualified.starts_with("SafeMemFn<")
}
38
/// Returns `true` if `type_name` spells an `UnsafeFn<...>` or `UnsafeMemFn<...>` wrapper type.
///
/// All whitespace is ignored. The wrapper may be written unqualified (`UnsafeFn<...>`),
/// namespace-qualified (`rusty::UnsafeFn<...>`), or globally qualified
/// (`::rusty::UnsafeFn<...>` / `::UnsafeFn<...>`).
pub fn is_unsafe_fn_type(type_name: &str) -> bool {
    let normalized: String = type_name.chars().filter(|c| !c.is_whitespace()).collect();
    // Peel an optional global-scope `::`, then an optional `rusty::` namespace.
    let unqualified = normalized.strip_prefix("::").unwrap_or(normalized.as_str());
    let unqualified = unqualified.strip_prefix("rusty::").unwrap_or(unqualified);
    unqualified.starts_with("UnsafeFn<") || unqualified.starts_with("UnsafeMemFn<")
}
47
/// Returns `true` if `type_name` is a raw (unwrapped) C++ function pointer or
/// member-function pointer type, e.g. `void (*)(int)` or `int (Widget::*)(double) const`.
///
/// Only the top level of the type is inspected. Text nested inside template argument
/// lists (`<...>`) is ignored, so wrappers such as `rusty::SafeMemFn<void (C::*)(int)>`
/// or `rusty::UnsafeFn<void (*)(int)>` are not reported merely because their template
/// argument is a pointer type. Whitespace is ignored.
pub fn is_raw_function_pointer_type(type_name: &str) -> bool {
    let mut top_level = String::with_capacity(type_name.len());
    let mut template_depth = 0usize;
    for c in type_name.chars() {
        match c {
            '<' => template_depth += 1,
            // Saturating so a stray '>' cannot underflow the depth.
            '>' => template_depth = template_depth.saturating_sub(1),
            _ if template_depth == 0 && !c.is_whitespace() => top_level.push(c),
            _ => {}
        }
    }
    // `(*)` marks a plain function pointer, `::*)` a pointer to member.
    top_level.contains("(*)") || top_level.contains("::*)")
}
56
57pub fn is_safe_fn_call(callee_type: &str, method_name: &str) -> bool {
59 is_safe_fn_type(callee_type) && method_name == "operator()"
61}
62
63pub fn is_unsafe_fn_call_unsafe_method(callee_type: &str, method_name: &str) -> bool {
65 is_unsafe_fn_type(callee_type) && method_name == "call_unsafe"
67}
68
/// Returns `true` if `type_name`, with ASCII spaces removed, contains `MemFn<`
/// anywhere (matching both `SafeMemFn<...>` and `UnsafeMemFn<...>`).
pub fn is_member_fn_wrapper_type(type_name: &str) -> bool {
    let without_spaces: String = type_name.split(' ').collect();
    without_spaces.contains("MemFn<")
}
74
/// Returns `true` if `type_name` is a raw (unwrapped) C++ pointer-to-member-function
/// type, e.g. `void (MyClass::*)(int)` or `int (Widget::*)(double) const`.
///
/// Only the top level of the type is inspected. Text inside template argument lists
/// (`<...>`) is ignored, so any wrapper is excluded, including cv/ref-qualified or
/// globally qualified spellings such as `const ::rusty::SafeMemFn<void (C::*)(int)> &`.
/// Whitespace is ignored.
pub fn is_raw_member_function_pointer_type(type_name: &str) -> bool {
    let mut top_level = String::with_capacity(type_name.len());
    let mut template_depth = 0usize;
    for c in type_name.chars() {
        match c {
            '<' => template_depth += 1,
            // Saturating so a stray '>' cannot underflow the depth.
            '>' => template_depth = template_depth.saturating_sub(1),
            _ if template_depth == 0 && !c.is_whitespace() => top_level.push(c),
            _ => {}
        }
    }
    // `Class::*)` is the pointer-to-member declarator.
    top_level.contains("::*)")
}
85
86pub fn check_function_pointer_safety(
97 function: &Function,
98 function_safety: SafetyMode,
99 known_safe_functions: &HashMap<String, SafetyMode>,
100) -> Vec<String> {
101 let mut errors = Vec::new();
102 let mut unsafe_depth = 0;
103
104 if function_safety != SafetyMode::Safe {
106 return errors;
107 }
108
109 for stmt in &function.body {
110 match stmt {
112 Statement::EnterUnsafe => {
113 unsafe_depth += 1;
114 continue;
115 }
116 Statement::ExitUnsafe => {
117 if unsafe_depth > 0 {
118 unsafe_depth -= 1;
119 }
120 continue;
121 }
122 _ => {}
123 }
124
125 let in_unsafe_scope = unsafe_depth > 0;
126
127 if let Some(error) = check_statement_for_function_pointer_safety(
128 stmt,
129 in_unsafe_scope,
130 known_safe_functions,
131 ) {
132 errors.push(format!("In function '{}': {}", function.name, error));
133 }
134 }
135
136 errors
137}
138
139fn check_statement_for_function_pointer_safety(
140 stmt: &Statement,
141 in_unsafe_scope: bool,
142 known_safe_functions: &HashMap<String, SafetyMode>,
143) -> Option<String> {
144 if in_unsafe_scope {
146 return None;
147 }
148
149 match stmt {
150 Statement::VariableDecl(var) => {
151 if is_safe_fn_type(&var.type_name) {
153 return None;
156 }
157
158 None
160 }
161
162 Statement::Assignment { lhs, rhs, location } => {
163 if let Expression::Variable(var_name) = lhs {
165 if let Some(error) = check_safe_fn_assignment_expr(
167 var_name,
168 rhs,
169 known_safe_functions,
170 location.line as usize,
171 ) {
172 return Some(error);
173 }
174 }
175 None
176 }
177
178 Statement::FunctionCall { name, args, location, .. } => {
179 if name.ends_with("::call_unsafe") || name.ends_with(".call_unsafe") {
185 return Some(format!(
186 "Call to UnsafeFn::call_unsafe() at line {} requires @unsafe context",
187 location.line as usize
188 ));
189 }
190
191 None
196 }
197
198 Statement::If { then_branch, else_branch, .. } => {
199 for stmt in then_branch {
201 if let Some(error) = check_statement_for_function_pointer_safety(
202 stmt, in_unsafe_scope, known_safe_functions
203 ) {
204 return Some(error);
205 }
206 }
207 if let Some(else_stmts) = else_branch {
208 for stmt in else_stmts {
209 if let Some(error) = check_statement_for_function_pointer_safety(
210 stmt, in_unsafe_scope, known_safe_functions
211 ) {
212 return Some(error);
213 }
214 }
215 }
216 None
217 }
218
219 Statement::Block(stmts) => {
220 for stmt in stmts {
221 if let Some(error) = check_statement_for_function_pointer_safety(
222 stmt, in_unsafe_scope, known_safe_functions
223 ) {
224 return Some(error);
225 }
226 }
227 None
228 }
229
230 _ => None,
231 }
232}
233
234fn check_safe_fn_assignment_expr(
236 _var_name: &str,
237 rhs: &Expression,
238 known_safe_functions: &HashMap<String, SafetyMode>,
239 line: usize,
240) -> Option<String> {
241 let func_name = match extract_function_from_address_of(rhs) {
243 Some(name) => name,
244 None => return None, };
246
247 match known_safe_functions.get(&func_name) {
249 Some(SafetyMode::Safe) => None, Some(SafetyMode::Unsafe) => {
251 Some(format!(
252 "Cannot assign @unsafe function '{}' to SafeFn at line {}. \
253 SafeFn can only hold pointers to @safe functions.",
254 func_name, line
255 ))
256 }
257 None => {
258 Some(format!(
260 "Cannot assign unannotated function '{}' to SafeFn at line {}. \
261 The target function must be marked @safe. \
262 (Unannotated functions are @unsafe by default)",
263 func_name, line
264 ))
265 }
266 }
267}
268
269fn extract_function_from_address_of(expr: &Expression) -> Option<String> {
271 match expr {
272 Expression::AddressOf(inner) => {
273 match inner.as_ref() {
274 Expression::Variable(name) => Some(name.clone()),
275 Expression::MemberAccess { object, field } => {
276 if let Expression::Variable(class_name) = object.as_ref() {
278 Some(format!("{}::{}", class_name, field))
279 } else {
280 None
281 }
282 }
283 _ => None,
284 }
285 }
286 Expression::Variable(name) => {
288 if name.contains("::") || name.starts_with(|c: char| c.is_uppercase()) {
291 Some(name.clone())
292 } else {
293 None
294 }
295 }
296 _ => None,
297 }
298}
299
300pub fn check_raw_function_pointer_call(
303 callee: &Expression,
304 callee_type: Option<&str>,
305 line: usize,
306) -> Option<String> {
307 if let Some(type_name) = callee_type {
309 if is_raw_function_pointer_type(type_name) {
310 return Some(format!(
311 "Call through raw function pointer at line {} requires @unsafe context. \
312 Consider using SafeFn<Sig> or UnsafeFn<Sig> wrapper types.",
313 line
314 ));
315 }
316 }
317
318 match callee {
320 Expression::Dereference(inner) => {
321 if let Expression::Variable(_) = inner.as_ref() {
323 return Some(format!(
324 "Call through dereferenced function pointer at line {} requires @unsafe context. \
325 Consider using SafeFn<Sig> or UnsafeFn<Sig> wrapper types.",
326 line
327 ));
328 }
329 }
330 _ => {}
331 }
332
333 None
334}
335
// Unit tests for the type-name classifiers and the SafeFn assignment check.
#[cfg(test)]
mod tests {
    use super::*;

    // Safe wrappers are recognised with or without the `rusty::` prefix.
    #[test]
    fn test_is_safe_fn_type() {
        assert!(is_safe_fn_type("rusty::SafeFn<void(int)>"));
        assert!(is_safe_fn_type("SafeFn<int(const char*)>"));
        assert!(is_safe_fn_type("rusty::SafeMemFn<void (MyClass::*)(int)>"));
        assert!(is_safe_fn_type("SafeMemFn<int (Widget::*)(double) const>"));

        // Unsafe wrappers, std::function and raw pointers are not safe wrappers.
        assert!(!is_safe_fn_type("rusty::UnsafeFn<void(int)>"));
        assert!(!is_safe_fn_type("std::function<void(int)>"));
        assert!(!is_safe_fn_type("void (*)(int)"));
    }

    // Unsafe wrappers are recognised with or without the `rusty::` prefix.
    #[test]
    fn test_is_unsafe_fn_type() {
        assert!(is_unsafe_fn_type("rusty::UnsafeFn<void(int)>"));
        assert!(is_unsafe_fn_type("UnsafeFn<int(const char*)>"));
        assert!(is_unsafe_fn_type("rusty::UnsafeMemFn<void (MyClass::*)(int)>"));

        assert!(!is_unsafe_fn_type("rusty::SafeFn<void(int)>"));
        assert!(!is_unsafe_fn_type("std::function<void(int)>"));
    }

    // Plain and member function pointer spellings are classified as raw pointers.
    #[test]
    fn test_is_raw_function_pointer_type() {
        assert!(is_raw_function_pointer_type("void (*)(int)"));
        assert!(is_raw_function_pointer_type("int (*)(const char*, ...)"));
        assert!(is_raw_function_pointer_type("void (MyClass::*)(int)"));
        assert!(is_raw_function_pointer_type("int (Widget::*)(double) const"));

        assert!(!is_raw_function_pointer_type("rusty::SafeFn<void(int)>"));
        assert!(!is_raw_function_pointer_type("std::function<void(int)>"));
        assert!(!is_raw_function_pointer_type("void"));
    }

    // Only `operator()` on a safe wrapper counts as a safe call.
    #[test]
    fn test_is_safe_fn_call() {
        assert!(is_safe_fn_call("rusty::SafeFn<void(int)>", "operator()"));
        assert!(is_safe_fn_call("SafeFn<int()>", "operator()"));

        assert!(!is_safe_fn_call("rusty::SafeFn<void(int)>", "get"));
        assert!(!is_safe_fn_call("rusty::UnsafeFn<void(int)>", "operator()"));
    }

    // Only `call_unsafe` on an unsafe wrapper matches.
    #[test]
    fn test_is_unsafe_fn_call_unsafe_method() {
        assert!(is_unsafe_fn_call_unsafe_method("rusty::UnsafeFn<void(int)>", "call_unsafe"));
        assert!(is_unsafe_fn_call_unsafe_method("UnsafeFn<int()>", "call_unsafe"));

        assert!(!is_unsafe_fn_call_unsafe_method("rusty::UnsafeFn<void(int)>", "get"));
        assert!(!is_unsafe_fn_call_unsafe_method("rusty::SafeFn<void(int)>", "call_unsafe"));
    }

    // Covers `&func`, `&Class.method` and a lowercase bare variable (rejected).
    #[test]
    fn test_extract_function_from_address_of() {
        let expr = Expression::AddressOf(Box::new(Expression::Variable("my_func".to_string())));
        assert_eq!(extract_function_from_address_of(&expr), Some("my_func".to_string()));

        let expr = Expression::AddressOf(Box::new(Expression::MemberAccess {
            object: Box::new(Expression::Variable("MyClass".to_string())),
            field: "method".to_string(),
        }));
        assert_eq!(extract_function_from_address_of(&expr), Some("MyClass::method".to_string()));

        // Lowercase, unqualified bare names are not treated as function references.
        let expr = Expression::Variable("x".to_string());
        assert_eq!(extract_function_from_address_of(&expr), None);
    }

    // A function recorded as Safe may be stored in a SafeFn.
    #[test]
    fn test_check_safe_fn_assignment_with_safe_function() {
        let mut known_safe: HashMap<String, SafetyMode> = HashMap::new();
        known_safe.insert("safe_func".to_string(), SafetyMode::Safe);

        let rhs = Expression::AddressOf(Box::new(Expression::Variable("safe_func".to_string())));

        let result = check_safe_fn_assignment_expr("callback", &rhs, &known_safe, 10);
        assert!(result.is_none(), "Assignment of @safe function to SafeFn should succeed");
    }

    // A function recorded as Unsafe is rejected with the "@unsafe function" message.
    #[test]
    fn test_check_safe_fn_assignment_with_unsafe_function() {
        let mut known_safe: HashMap<String, SafetyMode> = HashMap::new();
        known_safe.insert("unsafe_func".to_string(), SafetyMode::Unsafe);

        let rhs = Expression::AddressOf(Box::new(Expression::Variable("unsafe_func".to_string())));

        let result = check_safe_fn_assignment_expr("callback", &rhs, &known_safe, 10);
        assert!(result.is_some(), "Assignment of @unsafe function to SafeFn should fail");
        assert!(result.unwrap().contains("@unsafe function"));
    }

    // A function absent from the map is rejected as unannotated.
    #[test]
    fn test_check_safe_fn_assignment_with_unknown_function() {
        let known_safe: HashMap<String, SafetyMode> = HashMap::new();

        let rhs = Expression::AddressOf(Box::new(Expression::Variable("unknown_func".to_string())));

        let result = check_safe_fn_assignment_expr("callback", &rhs, &known_safe, 10);
        assert!(result.is_some(), "Assignment of unknown function to SafeFn should fail");
        assert!(result.unwrap().contains("unannotated function"));
    }

    // Member-function wrappers (safe or unsafe) are detected; plain wrappers and
    // raw member pointers are not.
    #[test]
    fn test_is_member_fn_wrapper_type() {
        assert!(is_member_fn_wrapper_type("rusty::SafeMemFn<void (MyClass::*)(int)>"));
        assert!(is_member_fn_wrapper_type("SafeMemFn<int (Widget::*)(double) const>"));
        assert!(is_member_fn_wrapper_type("rusty::UnsafeMemFn<void (MyClass::*)(int)>"));
        assert!(is_member_fn_wrapper_type("UnsafeMemFn<int (Widget::*)()>"));

        assert!(!is_member_fn_wrapper_type("rusty::SafeFn<void(int)>"));
        assert!(!is_member_fn_wrapper_type("void (MyClass::*)(int)"));
    }

    // Raw member pointers are detected; plain function pointers and wrapped member
    // pointers are not.
    #[test]
    fn test_is_raw_member_function_pointer_type() {
        assert!(is_raw_member_function_pointer_type("void (MyClass::*)(int)"));
        assert!(is_raw_member_function_pointer_type("int (Widget::*)(double) const"));
        assert!(is_raw_member_function_pointer_type("bool (Foo::*)()"));

        assert!(!is_raw_member_function_pointer_type("void (*)(int)"));
        assert!(!is_raw_member_function_pointer_type("rusty::SafeMemFn<void (MyClass::*)(int)>"));
    }

    // SafeMemFn's `operator()` is also a safe call.
    #[test]
    fn test_safe_mem_fn_call() {
        assert!(is_safe_fn_call("rusty::SafeMemFn<void (MyClass::*)(int)>", "operator()"));
        assert!(is_safe_fn_call("SafeMemFn<int (Widget::*)() const>", "operator()"));
    }

    // UnsafeMemFn's `call_unsafe` is also matched.
    #[test]
    fn test_unsafe_mem_fn_call_unsafe() {
        assert!(is_unsafe_fn_call_unsafe_method("rusty::UnsafeMemFn<void (MyClass::*)(int)>", "call_unsafe"));
        assert!(is_unsafe_fn_call_unsafe_method("UnsafeMemFn<int (Widget::*)()>", "call_unsafe"));
    }
}