1use std::collections::HashMap;
9
10use crate::ast::{
11 Annotation, AnnotationHandler, AnnotationHandlerType, Expr, Item, MethodDef, Program, Span,
12 Statement, TypeName,
13};
14
15pub fn augment_program_with_generated_extends(program: &Program) -> Program {
18 let mut augmented = program.clone();
19 let generated = collect_generated_annotation_extends(program);
20 for extend in generated {
21 augmented.items.push(Item::Extend(extend, Span::DUMMY));
22 }
23 augmented
24}
25
26pub fn collect_generated_annotation_extends(program: &Program) -> Vec<crate::ast::ExtendStatement> {
28 let mut comptime_handlers: HashMap<String, Vec<AnnotationHandler>> = HashMap::new();
29 for item in &program.items {
30 if let Item::AnnotationDef(ann_def, _) = item {
31 let handlers: Vec<_> = ann_def
32 .handlers
33 .iter()
34 .filter(|h| {
35 matches!(
36 h.handler_type,
37 AnnotationHandlerType::ComptimePre | AnnotationHandlerType::ComptimePost
38 )
39 })
40 .cloned()
41 .collect();
42 if !handlers.is_empty() {
43 comptime_handlers.insert(ann_def.name.clone(), handlers);
44 }
45 }
46 }
47
48 if comptime_handlers.is_empty() {
49 return Vec::new();
50 }
51
52 let mut methods_by_type: HashMap<String, Vec<MethodDef>> = HashMap::new();
53 collect_generated_annotation_extends_from_items(
54 &program.items,
55 &comptime_handlers,
56 &mut methods_by_type,
57 );
58
59 methods_by_type
60 .into_iter()
61 .map(|(type_name, methods)| crate::ast::ExtendStatement {
62 type_name: TypeName::Simple(type_name.into()),
63 methods,
64 })
65 .collect()
66}
67
68fn collect_generated_annotation_extends_from_items(
69 items: &[Item],
70 comptime_handlers: &HashMap<String, Vec<AnnotationHandler>>,
71 methods_by_type: &mut HashMap<String, Vec<MethodDef>>,
72) {
73 for item in items {
74 match item {
75 Item::StructType(struct_def, _) => collect_annotation_methods_for_target(
76 &struct_def.annotations,
77 &struct_def.name,
78 comptime_handlers,
79 methods_by_type,
80 ),
81 Item::Function(func_def, _) => collect_annotation_methods_for_target(
82 &func_def.annotations,
83 &func_def.name,
84 comptime_handlers,
85 methods_by_type,
86 ),
87 Item::Module(module_def, _) => collect_generated_annotation_extends_from_items(
88 &module_def.items,
89 comptime_handlers,
90 methods_by_type,
91 ),
92 _ => {}
93 }
94 }
95}
96
97fn collect_annotation_methods_for_target(
98 annotations: &[Annotation],
99 target_name: &str,
100 comptime_handlers: &HashMap<String, Vec<AnnotationHandler>>,
101 methods_by_type: &mut HashMap<String, Vec<MethodDef>>,
102) {
103 for ann in annotations {
104 let Some(handlers) = comptime_handlers.get(&ann.name) else {
105 continue;
106 };
107 for handler in handlers {
108 collect_extend_methods_from_expr(&handler.body, target_name, methods_by_type);
109 }
110 }
111}
112
113fn collect_extend_methods_from_expr(
114 expr: &Expr,
115 target_name: &str,
116 methods_by_type: &mut HashMap<String, Vec<MethodDef>>,
117) {
118 match expr {
119 Expr::Block(block, _) => {
120 for item in &block.items {
121 match item {
122 crate::ast::BlockItem::Statement(stmt) => {
123 collect_extend_methods_from_stmt(stmt, target_name, methods_by_type);
124 }
125 crate::ast::BlockItem::Expression(inner) => {
126 collect_extend_methods_from_expr(inner, target_name, methods_by_type);
127 }
128 _ => {}
129 }
130 }
131 }
132 Expr::Conditional {
133 then_expr,
134 else_expr,
135 ..
136 } => {
137 collect_extend_methods_from_expr(then_expr, target_name, methods_by_type);
138 if let Some(else_expr) = else_expr {
139 collect_extend_methods_from_expr(else_expr, target_name, methods_by_type);
140 }
141 }
142 Expr::Match(match_expr, _) => {
143 for arm in &match_expr.arms {
144 collect_extend_methods_from_expr(&arm.body, target_name, methods_by_type);
145 }
146 }
147 Expr::Annotated { target, .. } => {
148 collect_extend_methods_from_expr(target, target_name, methods_by_type);
149 }
150 Expr::Comptime(stmts, _) => {
151 for stmt in stmts {
152 collect_extend_methods_from_stmt(stmt, target_name, methods_by_type);
153 }
154 }
155 _ => {}
156 }
157}
158
159fn collect_extend_methods_from_stmt(
160 stmt: &Statement,
161 target_name: &str,
162 methods_by_type: &mut HashMap<String, Vec<MethodDef>>,
163) {
164 match stmt {
165 Statement::Extend(extend, _) => {
166 let resolved_type = match &extend.type_name {
167 TypeName::Simple(name) if name.as_str() == "target" => target_name.to_string(),
168 TypeName::Generic { name, .. } if name.as_str() == "target" => target_name.to_string(),
169 TypeName::Simple(name) => name.to_string(),
170 TypeName::Generic { name, .. } => name.to_string(),
171 };
172 let entry = methods_by_type.entry(resolved_type).or_default();
173 for method in &extend.methods {
174 if !entry.iter().any(|existing| existing.name == method.name) {
175 entry.push(method.clone());
176 }
177 }
178 }
179 Statement::If(if_stmt, _) => {
180 for stmt in &if_stmt.then_body {
181 collect_extend_methods_from_stmt(stmt, target_name, methods_by_type);
182 }
183 if let Some(else_body) = &if_stmt.else_body {
184 for stmt in else_body {
185 collect_extend_methods_from_stmt(stmt, target_name, methods_by_type);
186 }
187 }
188 }
189 Statement::For(for_loop, _) => {
190 for stmt in &for_loop.body {
191 collect_extend_methods_from_stmt(stmt, target_name, methods_by_type);
192 }
193 }
194 Statement::While(while_loop, _) => {
195 for stmt in &while_loop.body {
196 collect_extend_methods_from_stmt(stmt, target_name, methods_by_type);
197 }
198 }
199 Statement::Expression(expr, _) => {
200 collect_extend_methods_from_expr(expr, target_name, methods_by_type);
201 }
202 _ => {}
203 }
204}
205
206#[cfg(test)]
207mod tests {
208 use super::*;
209 use crate::parser::parse_program;
210
211 #[test]
212 fn collects_extend_target_for_annotated_type() {
213 let code = r#"
214annotation add_sum() {
215 targets: [type]
216 comptime post(target, ctx) {
217 extend target {
218 method sum() { self.x + self.y }
219 }
220 }
221}
222@add_sum()
223type Point { x: int, y: int }
224"#;
225 let program = parse_program(code).expect("parse");
226 let generated = collect_generated_annotation_extends(&program);
227 assert_eq!(generated.len(), 1, "expected one generated extend");
228 let ext = &generated[0];
229 match &ext.type_name {
230 TypeName::Simple(name) => assert_eq!(name, "Point"),
231 other => panic!("expected simple type name, got {:?}", other),
232 }
233 assert!(
234 ext.methods.iter().any(|m| m.name == "sum"),
235 "expected generated sum method"
236 );
237 }
238
239 #[test]
240 fn does_not_generate_for_unused_annotation() {
241 let code = r#"
242annotation add_sum() {
243 targets: [type]
244 comptime post(target, ctx) {
245 extend target {
246 method sum() { self.x + self.y }
247 }
248 }
249}
250type Point { x: int, y: int }
251"#;
252 let program = parse_program(code).expect("parse");
253 let generated = collect_generated_annotation_extends(&program);
254 assert!(
255 generated.is_empty(),
256 "unused annotations must not generate extends"
257 );
258 }
259}