1use std::collections::{HashMap, HashSet};
19use crate::parser::ast_visitor::Class;
20use crate::parser::safety_annotations::SafetyMode;
21use crate::debug_println;
22
23pub fn validate_interface(class: &Class) -> Vec<String> {
25 let mut errors = Vec::new();
26
27 if !class.is_interface {
28 return errors;
29 }
30
31 debug_println!("INHERITANCE: Validating @interface '{}'", class.name);
32
33 let non_static_members: Vec<_> = class.members.iter()
35 .filter(|m| !m.is_static)
36 .collect();
37
38 if !non_static_members.is_empty() {
39 let member_names: Vec<_> = non_static_members.iter()
40 .map(|m| m.name.as_str())
41 .collect();
42 errors.push(format!(
43 "@interface '{}' cannot have data members: {:?}",
44 class.name, member_names
45 ));
46 }
47
48 if !class.all_methods_pure_virtual {
50 errors.push(format!(
51 "@interface '{}' must have all pure virtual methods (= 0)",
52 class.name
53 ));
54 }
55
56 if !class.has_virtual_destructor {
63 errors.push(format!(
64 "@interface '{}' must have a virtual destructor (e.g., virtual ~{}() = default;)",
65 class.name, class.name
66 ));
67 } else if !class.destructor_is_defaulted {
68 errors.push(format!(
69 "@interface '{}' virtual destructor must be defaulted (use virtual ~{}() = default;)",
70 class.name, class.name
71 ));
72 }
73
74 if class.has_non_virtual_methods {
76 errors.push(format!(
77 "@interface '{}' cannot have non-virtual methods",
78 class.name
79 ));
80 }
81
82 errors
83}
84
85pub fn validate_interface_inheritance(
87 class: &Class,
88 interfaces: &HashSet<String>,
89) -> Vec<String> {
90 let mut errors = Vec::new();
91
92 if !class.is_interface {
93 return errors;
94 }
95
96 for base in &class.base_classes {
97 let base_name = strip_template_params(base);
99
100 if !is_interface_base(&base_name, interfaces, &class.name) {
102 errors.push(format!(
103 "@interface '{}' can only inherit from other @interface classes, not '{}'",
104 class.name, base
105 ));
106 }
107 }
108
109 errors
110}
111
112pub fn check_safe_inheritance(
114 class: &Class,
115 interfaces: &HashSet<String>,
116 class_safety: SafetyMode,
117) -> Vec<String> {
118 let mut errors = Vec::new();
119
120 if class_safety != SafetyMode::Safe {
122 return errors;
123 }
124
125 if class.base_classes.is_empty() {
127 return errors;
128 }
129
130 debug_println!("INHERITANCE: Checking safe inheritance for class '{}'", class.name);
131
132 for base in &class.base_classes {
133 let base_name = strip_template_params(base);
135
136 if !is_interface_base(&base_name, interfaces, &class.name) {
138 errors.push(format!(
139 "In @safe code, class '{}' can only inherit from @interface classes. \
140 '{}' is not an @interface. Use @unsafe context for regular inheritance.",
141 class.name, base
142 ));
143 }
144 }
145
146 errors
147}
148
149pub fn check_method_safety_contracts(
155 class: &Class,
156 interfaces: &HashMap<String, Class>,
157) -> Vec<String> {
158 let mut errors = Vec::new();
159
160 for base_name in &class.base_classes {
161 let base_stripped = strip_template_params(base_name);
162
163 let interface = match find_matching_interface(&base_stripped, interfaces, &class.name) {
165 Some(i) => i,
166 None => continue, };
168
169 debug_println!("INHERITANCE: Checking method safety contracts for '{}' implementing '{}'",
170 class.name, base_name);
171
172 for interface_method in &interface.methods {
174 let method_name_only = interface_method.name.split("::").last()
178 .unwrap_or(&interface_method.name);
179
180 if method_name_only.starts_with('~') || method_name_only == interface.name {
181 continue;
182 }
183
184 let interface_method_name = interface_method.name.split("::").last()
188 .unwrap_or(&interface_method.name);
189
190 let interface_param_types: Vec<&str> = interface_method.parameters.iter()
192 .map(|p| p.type_name.as_str())
193 .collect();
194
195 let impl_method = class.methods.iter()
196 .find(|m| {
197 let impl_name = m.name.split("::").last().unwrap_or(&m.name);
198 if impl_name != interface_method_name {
199 return false;
200 }
201 if m.parameters.len() != interface_method.parameters.len() {
203 return false;
204 }
205 m.parameters.iter().zip(interface_param_types.iter()).all(|(impl_param, iface_type)| {
207 let impl_type = impl_param.type_name.as_str();
208 let impl_type_base = impl_type.split("::").last().unwrap_or(impl_type);
210 let iface_type_base = iface_type.split("::").last().unwrap_or(iface_type);
211 impl_type_base == iface_type_base || impl_type == *iface_type
212 })
213 });
214
215 let Some(impl_method) = impl_method else { continue };
216
217 debug_println!("INHERITANCE: Found implementation of '{}' in '{}'",
218 interface_method.name, class.name);
219
220 let interface_safety = interface_method.safety_annotation
222 .unwrap_or(SafetyMode::Unsafe);
223
224 if impl_method.has_explicit_safety_annotation {
226 let impl_safety = impl_method.safety_annotation.unwrap_or(SafetyMode::Unsafe);
227
228 if impl_safety != interface_safety {
229 errors.push(format!(
230 "Method '{}::{}' annotated @{} but interface '{}' requires @{}",
231 class.name,
232 interface_method_name,
233 safety_mode_str(impl_safety),
234 strip_template_params(base_name),
235 safety_mode_str(interface_safety)
236 ));
237 }
238 }
239 let effective_safety = if impl_method.has_explicit_safety_annotation {
243 impl_method.safety_annotation.unwrap_or(SafetyMode::Unsafe)
244 } else {
245 interface_safety };
247
248 if effective_safety == SafetyMode::Safe {
250 let body_errors = validate_safe_method_body(impl_method, class, base_name);
251 errors.extend(body_errors);
252 }
253 }
254 }
255
256 errors
257}
258
259fn safety_mode_str(mode: SafetyMode) -> &'static str {
261 match mode {
262 SafetyMode::Safe => "safe",
263 SafetyMode::Unsafe => "unsafe",
264 }
265}
266
267fn validate_safe_method_body(
269 method: &crate::parser::ast_visitor::Function,
270 class: &Class,
271 interface_name: &str,
272) -> Vec<String> {
273 let mut errors = Vec::new();
274
275 let method_name = method.name.split("::").last().unwrap_or(&method.name);
276
277 for stmt in &method.body {
279 let stmt_errors = check_statement_safety(stmt, method_name, &class.name, interface_name);
280 errors.extend(stmt_errors);
281 }
282
283 errors
284}
285
286fn check_statement_safety(
288 stmt: &crate::parser::Statement,
289 method_name: &str,
290 class_name: &str,
291 interface_name: &str,
292) -> Vec<String> {
293 use crate::parser::Statement;
294
295 let mut errors = Vec::new();
296
297 match stmt {
298 Statement::ExpressionStatement { expr, .. } => {
299 let expr_errors = check_expression_safety(expr, method_name, class_name, interface_name);
300 errors.extend(expr_errors);
301 }
302 Statement::VariableDecl(_) => {
303 }
305 Statement::Assignment { lhs, rhs, .. } => {
306 errors.extend(check_expression_safety(lhs, method_name, class_name, interface_name));
307 errors.extend(check_expression_safety(rhs, method_name, class_name, interface_name));
308 }
309 Statement::ReferenceBinding { target, .. } => {
310 errors.extend(check_expression_safety(target, method_name, class_name, interface_name));
311 }
312 Statement::Return(Some(expr)) => {
313 let expr_errors = check_expression_safety(expr, method_name, class_name, interface_name);
314 errors.extend(expr_errors);
315 }
316 Statement::Return(None) => {}
317 Statement::FunctionCall { args, .. } => {
318 for arg in args {
319 errors.extend(check_expression_safety(arg, method_name, class_name, interface_name));
320 }
321 }
322 Statement::If { condition, then_branch, else_branch, .. } => {
323 let cond_errors = check_expression_safety(condition, method_name, class_name, interface_name);
324 errors.extend(cond_errors);
325 for s in then_branch {
326 errors.extend(check_statement_safety(s, method_name, class_name, interface_name));
327 }
328 if let Some(else_stmts) = else_branch {
329 for s in else_stmts {
330 errors.extend(check_statement_safety(s, method_name, class_name, interface_name));
331 }
332 }
333 }
334 Statement::Block(stmts) => {
335 for s in stmts {
336 errors.extend(check_statement_safety(s, method_name, class_name, interface_name));
337 }
338 }
339 Statement::LambdaExpr { .. } => {
340 }
342 Statement::PackExpansion { .. } => {
343 }
345 Statement::EnterScope | Statement::ExitScope |
347 Statement::EnterLoop | Statement::ExitLoop |
348 Statement::EnterUnsafe | Statement::ExitUnsafe => {}
349 }
350
351 errors
352}
353
354fn check_expression_safety(
356 expr: &crate::parser::Expression,
357 method_name: &str,
358 class_name: &str,
359 interface_name: &str,
360) -> Vec<String> {
361 use crate::parser::Expression;
362
363 let mut errors = Vec::new();
364
365 match expr {
366 Expression::Dereference(inner) => {
367 errors.push(format!(
369 "Method '{}::{}' violates @safe contract from interface '{}': pointer dereference in @safe context",
370 class_name, method_name, strip_template_params(interface_name)
371 ));
372 errors.extend(check_expression_safety(inner, method_name, class_name, interface_name));
374 }
375 Expression::AddressOf(inner) => {
376 errors.push(format!(
378 "Method '{}::{}' violates @safe contract from interface '{}': address-of operator in @safe context",
379 class_name, method_name, strip_template_params(interface_name)
380 ));
381 errors.extend(check_expression_safety(inner, method_name, class_name, interface_name));
382 }
383 Expression::FunctionCall { args, .. } => {
384 for arg in args {
386 errors.extend(check_expression_safety(arg, method_name, class_name, interface_name));
387 }
388 }
389 Expression::BinaryOp { left, right, .. } => {
390 errors.extend(check_expression_safety(left, method_name, class_name, interface_name));
391 errors.extend(check_expression_safety(right, method_name, class_name, interface_name));
392 }
393 Expression::MemberAccess { object, .. } => {
394 errors.extend(check_expression_safety(object, method_name, class_name, interface_name));
395 }
396 Expression::Cast(inner) => {
397 errors.extend(check_expression_safety(inner, method_name, class_name, interface_name));
400 }
401 Expression::Move { inner, .. } => {
402 errors.extend(check_expression_safety(inner, method_name, class_name, interface_name));
403 }
404 Expression::Lambda { .. } => {
405 }
407 Expression::Variable(_) | Expression::Literal(_) | Expression::StringLiteral(_) => {}
410 }
411
412 errors
413}
414
415pub fn collect_interfaces(classes: &[Class]) -> HashSet<String> {
417 classes.iter()
418 .filter(|c| c.is_interface)
419 .map(|c| c.name.clone())
420 .collect()
421}
422
423pub fn collect_interface_map(classes: &[Class]) -> HashMap<String, Class> {
425 classes.iter()
426 .filter(|c| c.is_interface)
427 .map(|c| (c.name.clone(), c.clone()))
428 .collect()
429}
430
431pub fn check_inheritance_safety(classes: &[Class]) -> Vec<String> {
433 let mut errors = Vec::new();
434
435 let interfaces = collect_interfaces(classes);
437 let interface_map = collect_interface_map(classes);
438
439 debug_println!("INHERITANCE: Found {} @interface classes: {:?}",
440 interfaces.len(), interfaces);
441
442 for class in classes {
444 errors.extend(validate_interface(class));
445 }
446
447 for class in classes {
449 errors.extend(validate_interface_inheritance(class, &interfaces));
450 }
451
452 for class in classes {
454 let class_safety = class.safety_annotation.unwrap_or(SafetyMode::Unsafe);
456 errors.extend(check_safe_inheritance(class, &interfaces, class_safety));
457 }
458
459 for class in classes {
461 errors.extend(check_method_safety_contracts(class, &interface_map));
462 }
463
464 for class in classes {
466 errors.extend(check_safe_class_copy_semantics(class));
467 }
468
469 errors
470}
471
/// Returns `type_name` truncated at its first `<`, or unchanged when it
/// contains none (e.g. `Base<T>` becomes `Base`).
fn strip_template_params(type_name: &str) -> String {
    type_name
        .split_once('<')
        .map_or(type_name, |(head, _template_args)| head)
        .to_string()
}
481
/// Resolves a base-class name as written on `derived_class_name` to the
/// qualified form used for interface lookups.
///
/// * A name starting with the C++ global-scope qualifier (`::Name`) is
///   returned without the leading `::`, so it can match names that are stored
///   without it. NOTE(review): whether the parser ever preserves a leading
///   `::` on base names is not visible here; the case is handled defensively.
/// * A name that already contains `::` is returned unchanged.
/// * An unqualified name is placed in the namespace of the derived class
///   (everything before the last `::` of `derived_class_name`); when the
///   derived class has no namespace the name is returned unchanged.
///
/// Only the derived class's own namespace is considered; enclosing namespaces
/// and the global namespace are not searched for unqualified names.
fn resolve_qualified_name(base_name: &str, derived_class_name: &str) -> String {
    if let Some(global_name) = base_name.strip_prefix("::") {
        return global_name.to_string();
    }

    if base_name.contains("::") {
        return base_name.to_string();
    }

    match derived_class_name.rsplit_once("::") {
        Some((namespace, _class_name)) => format!("{}::{}", namespace, base_name),
        None => base_name.to_string(),
    }
}
508
509fn is_interface_base(
514 base_name: &str,
515 interfaces: &HashSet<String>,
516 derived_class_name: &str,
517) -> bool {
518 let qualified_base = resolve_qualified_name(base_name, derived_class_name);
519 interfaces.contains(&qualified_base)
520}
521
522fn find_matching_interface<'a>(
526 base_name: &str,
527 interface_map: &'a HashMap<String, Class>,
528 derived_class_name: &str,
529) -> Option<&'a Class> {
530 let qualified_base = resolve_qualified_name(base_name, derived_class_name);
531 interface_map.get(&qualified_base)
532}
533
534pub fn check_safe_class_copy_semantics(class: &Class) -> Vec<String> {
545 let mut errors = Vec::new();
546
547 let class_safety = class.safety_annotation.unwrap_or(SafetyMode::Unsafe);
549 if class_safety != SafetyMode::Safe {
550 return errors;
551 }
552
553 debug_println!("COPY CHECK: Checking copy semantics for @safe class '{}'", class.name);
554
555 if class.has_copy_constructor && !class.copy_constructor_deleted {
557 errors.push(format!(
558 "@safe class '{}' cannot have a copy constructor. \
559 Use '= delete' to disable copying, or mark the class @unsafe. \
560 Rust-like move semantics require types to be moved, not copied.",
561 class.name
562 ));
563 }
564
565 if class.has_copy_assignment && !class.copy_assignment_deleted {
567 errors.push(format!(
568 "@safe class '{}' cannot have a copy assignment operator. \
569 Use '= delete' to disable copying, or mark the class @unsafe. \
570 Rust-like move semantics require types to be moved, not copied.",
571 class.name
572 ));
573 }
574
575 errors
576}
577
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::ast_visitor::{Class, SourceLocation, Variable};

    /// Placeholder source location shared by every fixture.
    fn make_location() -> SourceLocation {
        SourceLocation {
            file: "test.cpp".to_string(),
            line: 1,
            column: 1,
        }
    }

    /// Builds an unannotated, non-interface class with the given bases and
    /// every flag cleared.
    fn make_class(name: &str, base_classes: Vec<String>) -> Class {
        Class {
            name: name.to_string(),
            template_parameters: Vec::new(),
            is_template: false,
            members: Vec::new(),
            methods: Vec::new(),
            base_classes,
            location: make_location(),
            has_destructor: false,
            is_interface: false,
            has_virtual_destructor: false,
            destructor_is_defaulted: false,
            all_methods_pure_virtual: false,
            has_non_virtual_methods: false,
            safety_annotation: None,
            has_copy_constructor: false,
            has_copy_assignment: false,
            copy_constructor_deleted: false,
            copy_assignment_deleted: false,
        }
    }

    /// Builds an interface that passes `validate_interface`: no bases, no
    /// members, all methods pure virtual, defaulted virtual destructor.
    fn make_interface(name: &str) -> Class {
        Class {
            has_destructor: true,
            is_interface: true,
            has_virtual_destructor: true,
            destructor_is_defaulted: true,
            all_methods_pure_virtual: true,
            ..make_class(name, Vec::new())
        }
    }

    #[test]
    fn test_valid_interface() {
        let interface = make_interface("IDrawable");
        let errors = validate_interface(&interface);
        assert!(errors.is_empty(), "Valid interface should have no errors");
    }

    #[test]
    fn test_interface_with_data_member() {
        let mut interface = make_interface("IBadInterface");
        interface.members.push(Variable {
            name: "data".to_string(),
            type_name: "int".to_string(),
            is_reference: false,
            is_pointer: false,
            is_const: false,
            is_unique_ptr: false,
            is_shared_ptr: false,
            is_static: false,
            is_mutable: false,
            location: make_location(),
            is_pack: false,
            pack_element_type: None,
        });

        let errors = validate_interface(&interface);
        assert_eq!(errors.len(), 1);
        assert!(errors[0].contains("cannot have data members"));
    }

    #[test]
    fn test_interface_with_non_virtual_destructor() {
        let mut interface = make_interface("IBadInterface");
        interface.has_destructor = true;
        interface.has_virtual_destructor = false;

        let errors = validate_interface(&interface);
        assert_eq!(errors.len(), 1);
        assert!(errors[0].contains("virtual destructor"));
    }

    #[test]
    fn test_interface_without_any_destructor() {
        let mut interface = make_interface("IBadInterface");
        interface.has_destructor = false;
        interface.has_virtual_destructor = false;
        interface.destructor_is_defaulted = false;

        let errors = validate_interface(&interface);
        assert_eq!(errors.len(), 1);
        assert!(errors[0].contains("virtual destructor"));
    }

    #[test]
    fn test_interface_with_non_defaulted_virtual_destructor() {
        let mut interface = make_interface("IBadInterface");
        interface.has_destructor = true;
        interface.has_virtual_destructor = true;
        interface.destructor_is_defaulted = false;

        let errors = validate_interface(&interface);
        assert_eq!(errors.len(), 1);
        assert!(errors[0].contains("must be defaulted"));
    }

    #[test]
    fn test_safe_inheritance_from_interface() {
        // The interface set is derived from the fixture so the fixture is
        // actually exercised.
        let interfaces = collect_interfaces(&[make_interface("IDrawable")]);
        let mut derived = make_class("Circle", vec!["IDrawable".to_string()]);
        derived.safety_annotation = Some(SafetyMode::Safe);

        let errors = check_safe_inheritance(&derived, &interfaces, SafetyMode::Safe);
        assert!(errors.is_empty(), "Safe inheritance from interface should be allowed");
    }

    #[test]
    fn test_safe_inheritance_from_non_interface() {
        // A plain class contributes nothing to the interface set.
        let interfaces = collect_interfaces(&[make_class("Base", Vec::new())]);
        assert!(interfaces.is_empty());

        let mut derived = make_class("Derived", vec!["Base".to_string()]);
        derived.safety_annotation = Some(SafetyMode::Safe);

        let errors = check_safe_inheritance(&derived, &interfaces, SafetyMode::Safe);
        assert_eq!(errors.len(), 1);
        assert!(errors[0].contains("can only inherit from @interface"));
    }

    #[test]
    fn test_resolve_qualified_name_already_qualified() {
        assert_eq!(
            resolve_qualified_name("other::IDrawable", "myapp::Circle"),
            "other::IDrawable"
        );
        assert_eq!(
            resolve_qualified_name("foo::bar::IDrawable", "myapp::Circle"),
            "foo::bar::IDrawable"
        );
    }

    #[test]
    fn test_resolve_qualified_name_from_namespace() {
        assert_eq!(
            resolve_qualified_name("IDrawable", "myapp::Circle"),
            "myapp::IDrawable"
        );
        assert_eq!(
            resolve_qualified_name("IDrawable", "myapp::inner::Circle"),
            "myapp::inner::IDrawable"
        );
    }

    #[test]
    fn test_resolve_qualified_name_no_namespace() {
        assert_eq!(resolve_qualified_name("IDrawable", "Circle"), "IDrawable");
    }

    #[test]
    fn test_strip_template_params() {
        assert_eq!(strip_template_params("Base<T>"), "Base");
        assert_eq!(
            strip_template_params("ns::Base<std::vector<int>>"),
            "ns::Base"
        );
        assert_eq!(strip_template_params("Plain"), "Plain");
    }

    #[test]
    fn test_is_interface_base_exact_match() {
        let interfaces = HashSet::from([
            "myapp::IDrawable".to_string(),
            "myapp::ISerializable".to_string(),
        ]);

        assert!(is_interface_base("IDrawable", &interfaces, "myapp::Circle"));
        assert!(is_interface_base("ISerializable", &interfaces, "myapp::Widget"));

        assert!(!is_interface_base("NonExistent", &interfaces, "myapp::Circle"));
    }

    #[test]
    fn test_is_interface_base_wrong_namespace_no_match() {
        let interfaces = HashSet::from(["myapp::IDrawable".to_string()]);

        // Unqualified name resolves into the derived class's namespace.
        assert!(!is_interface_base("IDrawable", &interfaces, "other::Circle"));

        // Explicitly qualified name is taken as written.
        assert!(!is_interface_base("other::IDrawable", &interfaces, "myapp::Circle"));
    }

    #[test]
    fn test_safe_inheritance_with_namespaced_interface() {
        let interfaces = collect_interfaces(&[make_interface("myapp::IDrawable")]);

        let mut derived = make_class("myapp::Circle", vec!["IDrawable".to_string()]);
        derived.safety_annotation = Some(SafetyMode::Safe);

        let errors = check_safe_inheritance(&derived, &interfaces, SafetyMode::Safe);
        assert!(errors.is_empty(), "Safe inheritance from namespaced interface should be allowed. Errors: {:?}", errors);
    }

    #[test]
    fn test_safe_inheritance_wrong_namespace_fails() {
        let interfaces = collect_interfaces(&[make_interface("myapp::IDrawable")]);

        let mut derived = make_class("other::Circle", vec!["IDrawable".to_string()]);
        derived.safety_annotation = Some(SafetyMode::Safe);

        let errors = check_safe_inheritance(&derived, &interfaces, SafetyMode::Safe);
        assert_eq!(errors.len(), 1, "Should fail: IDrawable resolves to other::IDrawable which is not an interface");
        assert!(errors[0].contains("not an @interface"));
    }

    #[test]
    fn test_interface_inheritance_with_namespace() {
        let mut child_interface = make_interface("myapp::IExtendedDrawable");
        child_interface.base_classes = vec!["IDrawable".to_string()];

        let interfaces = HashSet::from([
            "myapp::IDrawable".to_string(),
            "myapp::IExtendedDrawable".to_string(),
        ]);

        let errors = validate_interface_inheritance(&child_interface, &interfaces);
        assert!(errors.is_empty(), "Interface inheritance from namespaced interface should be allowed. Errors: {:?}", errors);
    }

    #[test]
    fn test_safe_class_copy_semantics() {
        let mut widget = make_class("Widget", Vec::new());
        widget.has_copy_constructor = true;
        widget.has_copy_assignment = true;

        // Unannotated classes are not checked.
        assert!(check_safe_class_copy_semantics(&widget).is_empty());

        widget.safety_annotation = Some(SafetyMode::Safe);
        let errors = check_safe_class_copy_semantics(&widget);
        assert_eq!(errors.len(), 2);
        assert!(errors[0].contains("copy constructor"));
        assert!(errors[1].contains("copy assignment operator"));

        // Deleted copy operations are accepted.
        widget.copy_constructor_deleted = true;
        widget.copy_assignment_deleted = true;
        assert!(check_safe_class_copy_semantics(&widget).is_empty());
    }
}