Skip to main content

shape_ast/transform/
comptime_extends.rs

1//! Comptime annotation expansion utilities.
2//!
3//! This transform extracts direct `extend ... { ... }` directives from
4//! `annotation ... comptime pre/post(...) { ... }` handler bodies and materializes
5//! them as synthetic top-level `Item::Extend` entries. It is intentionally
6//! static/AST-driven and does not execute comptime code.
7
8use std::collections::HashMap;
9
10use crate::ast::{
11    Annotation, AnnotationHandler, AnnotationHandlerType, Expr, Item, MethodDef, Program, Span,
12    Statement, TypeName,
13};
14
15/// Return a cloned program augmented with synthetic `Item::Extend` items
16/// derived from direct comptime handler directives on annotated targets.
17pub fn augment_program_with_generated_extends(program: &Program) -> Program {
18    let mut augmented = program.clone();
19    let generated = collect_generated_annotation_extends(program);
20    for extend in generated {
21        augmented.items.push(Item::Extend(extend, Span::DUMMY));
22    }
23    augmented
24}
25
26/// Collect synthetic extends generated by directly declared annotation comptime handlers.
27pub fn collect_generated_annotation_extends(program: &Program) -> Vec<crate::ast::ExtendStatement> {
28    let mut comptime_handlers: HashMap<String, Vec<AnnotationHandler>> = HashMap::new();
29    for item in &program.items {
30        if let Item::AnnotationDef(ann_def, _) = item {
31            let handlers: Vec<_> = ann_def
32                .handlers
33                .iter()
34                .filter(|h| {
35                    matches!(
36                        h.handler_type,
37                        AnnotationHandlerType::ComptimePre | AnnotationHandlerType::ComptimePost
38                    )
39                })
40                .cloned()
41                .collect();
42            if !handlers.is_empty() {
43                comptime_handlers.insert(ann_def.name.clone(), handlers);
44            }
45        }
46    }
47
48    if comptime_handlers.is_empty() {
49        return Vec::new();
50    }
51
52    let mut methods_by_type: HashMap<String, Vec<MethodDef>> = HashMap::new();
53    collect_generated_annotation_extends_from_items(
54        &program.items,
55        &comptime_handlers,
56        &mut methods_by_type,
57    );
58
59    methods_by_type
60        .into_iter()
61        .map(|(type_name, methods)| crate::ast::ExtendStatement {
62            type_name: TypeName::Simple(type_name.into()),
63            methods,
64        })
65        .collect()
66}
67
68fn collect_generated_annotation_extends_from_items(
69    items: &[Item],
70    comptime_handlers: &HashMap<String, Vec<AnnotationHandler>>,
71    methods_by_type: &mut HashMap<String, Vec<MethodDef>>,
72) {
73    for item in items {
74        match item {
75            Item::StructType(struct_def, _) => collect_annotation_methods_for_target(
76                &struct_def.annotations,
77                &struct_def.name,
78                comptime_handlers,
79                methods_by_type,
80            ),
81            Item::Function(func_def, _) => collect_annotation_methods_for_target(
82                &func_def.annotations,
83                &func_def.name,
84                comptime_handlers,
85                methods_by_type,
86            ),
87            Item::Module(module_def, _) => collect_generated_annotation_extends_from_items(
88                &module_def.items,
89                comptime_handlers,
90                methods_by_type,
91            ),
92            _ => {}
93        }
94    }
95}
96
97fn collect_annotation_methods_for_target(
98    annotations: &[Annotation],
99    target_name: &str,
100    comptime_handlers: &HashMap<String, Vec<AnnotationHandler>>,
101    methods_by_type: &mut HashMap<String, Vec<MethodDef>>,
102) {
103    for ann in annotations {
104        let Some(handlers) = comptime_handlers.get(&ann.name) else {
105            continue;
106        };
107        for handler in handlers {
108            collect_extend_methods_from_expr(&handler.body, target_name, methods_by_type);
109        }
110    }
111}
112
113fn collect_extend_methods_from_expr(
114    expr: &Expr,
115    target_name: &str,
116    methods_by_type: &mut HashMap<String, Vec<MethodDef>>,
117) {
118    match expr {
119        Expr::Block(block, _) => {
120            for item in &block.items {
121                match item {
122                    crate::ast::BlockItem::Statement(stmt) => {
123                        collect_extend_methods_from_stmt(stmt, target_name, methods_by_type);
124                    }
125                    crate::ast::BlockItem::Expression(inner) => {
126                        collect_extend_methods_from_expr(inner, target_name, methods_by_type);
127                    }
128                    _ => {}
129                }
130            }
131        }
132        Expr::Conditional {
133            then_expr,
134            else_expr,
135            ..
136        } => {
137            collect_extend_methods_from_expr(then_expr, target_name, methods_by_type);
138            if let Some(else_expr) = else_expr {
139                collect_extend_methods_from_expr(else_expr, target_name, methods_by_type);
140            }
141        }
142        Expr::Match(match_expr, _) => {
143            for arm in &match_expr.arms {
144                collect_extend_methods_from_expr(&arm.body, target_name, methods_by_type);
145            }
146        }
147        Expr::Annotated { target, .. } => {
148            collect_extend_methods_from_expr(target, target_name, methods_by_type);
149        }
150        Expr::Comptime(stmts, _) => {
151            for stmt in stmts {
152                collect_extend_methods_from_stmt(stmt, target_name, methods_by_type);
153            }
154        }
155        _ => {}
156    }
157}
158
159fn collect_extend_methods_from_stmt(
160    stmt: &Statement,
161    target_name: &str,
162    methods_by_type: &mut HashMap<String, Vec<MethodDef>>,
163) {
164    match stmt {
165        Statement::Extend(extend, _) => {
166            let resolved_type = match &extend.type_name {
167                TypeName::Simple(name) if name.as_str() == "target" => target_name.to_string(),
168                TypeName::Generic { name, .. } if name.as_str() == "target" => target_name.to_string(),
169                TypeName::Simple(name) => name.to_string(),
170                TypeName::Generic { name, .. } => name.to_string(),
171            };
172            let entry = methods_by_type.entry(resolved_type).or_default();
173            for method in &extend.methods {
174                if !entry.iter().any(|existing| existing.name == method.name) {
175                    entry.push(method.clone());
176                }
177            }
178        }
179        Statement::If(if_stmt, _) => {
180            for stmt in &if_stmt.then_body {
181                collect_extend_methods_from_stmt(stmt, target_name, methods_by_type);
182            }
183            if let Some(else_body) = &if_stmt.else_body {
184                for stmt in else_body {
185                    collect_extend_methods_from_stmt(stmt, target_name, methods_by_type);
186                }
187            }
188        }
189        Statement::For(for_loop, _) => {
190            for stmt in &for_loop.body {
191                collect_extend_methods_from_stmt(stmt, target_name, methods_by_type);
192            }
193        }
194        Statement::While(while_loop, _) => {
195            for stmt in &while_loop.body {
196                collect_extend_methods_from_stmt(stmt, target_name, methods_by_type);
197            }
198        }
199        Statement::Expression(expr, _) => {
200            collect_extend_methods_from_expr(expr, target_name, methods_by_type);
201        }
202        _ => {}
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_program;

    /// An annotation whose comptime handler contains `extend target { ... }`
    /// must yield an extend for the annotated type, carrying the declared method.
    #[test]
    fn collects_extend_target_for_annotated_type() {
        let code = r#"
annotation add_sum() {
    targets: [type]
    comptime post(target, ctx) {
        extend target {
            method sum() { self.x + self.y }
        }
    }
}
@add_sum()
type Point { x: int, y: int }
"#;
        let parsed = parse_program(code).expect("parse");
        let extends = collect_generated_annotation_extends(&parsed);
        assert_eq!(extends.len(), 1, "expected one generated extend");

        let only_extend = &extends[0];
        let TypeName::Simple(name) = &only_extend.type_name else {
            panic!(
                "expected simple type name, got {:?}",
                &only_extend.type_name
            );
        };
        assert_eq!(name, "Point");

        let has_sum = only_extend
            .methods
            .iter()
            .any(|method| method.name == "sum");
        assert!(has_sum, "expected generated sum method");
    }

    /// Declaring an annotation without applying it anywhere must not generate
    /// any extends.
    #[test]
    fn does_not_generate_for_unused_annotation() {
        let code = r#"
annotation add_sum() {
    targets: [type]
    comptime post(target, ctx) {
        extend target {
            method sum() { self.x + self.y }
        }
    }
}
type Point { x: int, y: int }
"#;
        let parsed = parse_program(code).expect("parse");
        let extends = collect_generated_annotation_extends(&parsed);
        let nothing_generated = extends.is_empty();
        assert!(
            nothing_generated,
            "unused annotations must not generate extends"
        );
    }
}