Skip to main content

treesitter_types/codegen/
type_mapper.rs

1use super::grammar_ir::{FieldInfo, NodeType, TypeRef};
2use super::name_mangler;
3use proc_macro2::Ident;
4
/// What Rust type to emit for a top-level node.
///
/// `map_types` produces one value of this enum for every named node it is given.
#[derive(Debug, Clone)]
pub enum TypeDecision {
    /// A named node with fields and/or children → emit a struct.
    Struct(StructDef),
    /// A named terminal node (no fields, no children, no subtypes) → emit a leaf struct.
    LeafStruct(LeafStructDef),
    /// A supertype node (has `subtypes`) → emit an enum.
    ///
    /// The `subtypes` check is made first, so a node that has subtypes is always mapped to
    /// this variant even if it also declares fields or children.
    SupertypeEnum(SupertypeEnumDef),
}
15
/// Description of the struct to generate for a named node that has fields and/or children.
#[derive(Debug, Clone)]
pub struct StructDef {
    /// Rust identifier of the generated struct, built from the node's kind by
    /// `name_mangler::type_ident`.
    pub type_name: Ident,
    /// The source node's `type_name`, copied unchanged.
    pub kind: String,
    /// One entry per field declared on the node.
    pub fields: Vec<FieldDef>,
    /// Mapping of the node's `children` entry; `None` when the node has no such entry.
    pub children: Option<ChildrenDef>,
}
23
/// Description of the struct to generate for a named node without fields, children or subtypes.
#[derive(Debug, Clone)]
pub struct LeafStructDef {
    /// Rust identifier of the generated struct, built from the node's kind by
    /// `name_mangler::type_ident`.
    pub type_name: Ident,
    /// The source node's `type_name`, copied unchanged.
    pub kind: String,
}
29
/// Description of the enum to generate for a supertype node (a node that lists `subtypes`).
#[derive(Debug, Clone)]
pub struct SupertypeEnumDef {
    /// Rust identifier of the generated enum; a leading `_` on the kind is dropped before the
    /// name is mangled (see `supertype_ident`).
    pub type_name: Ident,
    /// The source node's `type_name`, copied unchanged (including any leading `_`).
    pub kind: String,
    /// One variant per entry in the node's `subtypes` list, in the same order.
    pub variants: Vec<VariantDef>,
}
36
/// Description of a single named field on a generated struct.
#[derive(Debug, Clone)]
pub struct FieldDef {
    /// Rust identifier for the field, produced by `name_mangler::field_ident`.
    pub field_name: Ident,
    /// The original field name string (for `child_by_field_name`).
    pub raw_field_name: String,
    /// The field's type together with its cardinality wrapper.
    pub field_type: FieldType,
}
44
/// Description of a node's `children` entry (as opposed to its named fields).
#[derive(Debug, Clone)]
pub struct ChildrenDef {
    /// The children's type together with its cardinality wrapper. When an alternation enum is
    /// needed, it is named using the fixed field name `"children"`.
    pub field_type: FieldType,
}
49
/// How to represent a field's type in Rust.
///
/// The variant is chosen from the grammar's `required` / `multiple` flags: `multiple` always
/// yields `Repeated` (whatever `required` says); otherwise `required` yields `Direct` and
/// non-required yields `Optional`.
#[derive(Debug, Clone)]
pub enum FieldType {
    /// A single concrete type (e.g., `Identifier<'tree>`).
    Direct(TypeReference),
    /// Wrapped in `Option<T>`.
    Optional(TypeReference),
    /// Wrapped in `Vec<T>`.
    Repeated(TypeReference),
}
60
/// Reference to a Rust type, which may be a single named type or an alternation enum.
///
/// `Named` is used only when the field admits exactly one type and that type is a named node;
/// every other combination (several types, or any anonymous type) becomes an `Alternation`.
#[derive(Debug, Clone)]
pub enum TypeReference {
    /// A single named node type.
    Named(Ident),
    /// An alternation enum generated for this field.
    Alternation(AlternationEnumDef),
}
69
/// Description of an enum generated for a field (or `children` entry) that admits more than
/// one node type, or an anonymous one.
#[derive(Debug, Clone)]
pub struct AlternationEnumDef {
    /// Rust identifier of the enum: the parent struct's name followed by the Pascal-cased
    /// field name (e.g. `ImportSpec` + `name` → `ImportSpecName`).
    pub type_name: Ident,
    /// One variant per admissible type, in grammar order; anonymous types are included.
    pub variants: Vec<VariantDef>,
}
75
/// Description of one variant of a generated enum (supertype enum or alternation enum).
#[derive(Debug, Clone)]
pub struct VariantDef {
    /// Rust identifier of the variant, produced by `name_mangler::variant_name` from the
    /// referenced kind and its `named` flag.
    pub variant_name: Ident,
    /// The referenced node's kind string, copied unchanged from the grammar type reference.
    pub kind: String,
    /// Whether the referenced node is a named node (`false` for anonymous nodes).
    pub named: bool,
    /// True if this variant references a supertype (has subtypes in the grammar).
    /// Supertypes are abstract — tree-sitter never emits them as node kinds at runtime.
    pub is_supertype: bool,
}
85
86/// Maps all `NodeType` entries into `TypeDecision`s.
87pub fn map_types(nodes: &[NodeType]) -> Vec<TypeDecision> {
88    // Collect the set of type names that are supertypes
89    let supertype_kinds: std::collections::HashSet<&str> = nodes
90        .iter()
91        .filter(|n| n.subtypes.is_some())
92        .map(|n| n.type_name.as_str())
93        .collect();
94
95    nodes
96        .iter()
97        .filter(|n| n.named)
98        .map(|n| map_node(n, &supertype_kinds))
99        .collect()
100}
101
102fn map_node(node: &NodeType, supertype_kinds: &std::collections::HashSet<&str>) -> TypeDecision {
103    let raw_kind = &node.type_name;
104
105    // Supertype nodes (e.g., _expression, statement) → enum
106    if let Some(subtypes) = &node.subtypes {
107        return TypeDecision::SupertypeEnum(SupertypeEnumDef {
108            type_name: supertype_ident(raw_kind),
109            kind: raw_kind.clone(),
110            variants: subtypes
111                .iter()
112                .map(|tr| make_variant_def(tr, supertype_kinds))
113                .collect(),
114        });
115    }
116
117    // Named node with fields or children → struct
118    if !node.fields.is_empty() || node.children.is_some() {
119        let type_name = name_mangler::type_ident(raw_kind);
120        let fields = node
121            .fields
122            .iter()
123            .map(|(field_name, field_info)| {
124                let parent_name = type_name.to_string();
125                map_field(field_name, field_info, &parent_name, supertype_kinds)
126            })
127            .collect();
128        let children = node.children.as_ref().map(|c| {
129            let parent_name = type_name.to_string();
130            map_children(c, &parent_name, supertype_kinds)
131        });
132        return TypeDecision::Struct(StructDef {
133            type_name,
134            kind: raw_kind.clone(),
135            fields,
136            children,
137        });
138    }
139
140    // Leaf node (named, no fields, no children, no subtypes)
141    TypeDecision::LeafStruct(LeafStructDef {
142        type_name: name_mangler::type_ident(raw_kind),
143        kind: raw_kind.clone(),
144    })
145}
146
147fn make_variant_def(tr: &TypeRef, supertype_kinds: &std::collections::HashSet<&str>) -> VariantDef {
148    VariantDef {
149        variant_name: name_mangler::variant_name(&tr.type_name, tr.named),
150        kind: tr.type_name.clone(),
151        named: tr.named,
152        is_supertype: supertype_kinds.contains(tr.type_name.as_str()),
153    }
154}
155
156fn map_field(
157    field_name: &str,
158    field_info: &FieldInfo,
159    parent_name: &str,
160    supertype_kinds: &std::collections::HashSet<&str>,
161) -> FieldDef {
162    let type_ref = map_type_reference(&field_info.types, parent_name, field_name, supertype_kinds);
163    let field_type = match (field_info.required, field_info.multiple) {
164        (_, true) => FieldType::Repeated(type_ref),
165        (false, false) => FieldType::Optional(type_ref),
166        (true, false) => FieldType::Direct(type_ref),
167    };
168
169    FieldDef {
170        field_name: name_mangler::field_ident(field_name),
171        raw_field_name: field_name.to_owned(),
172        field_type,
173    }
174}
175
176fn map_children(
177    children: &FieldInfo,
178    parent_name: &str,
179    supertype_kinds: &std::collections::HashSet<&str>,
180) -> ChildrenDef {
181    let type_ref = map_type_reference(&children.types, parent_name, "children", supertype_kinds);
182    let field_type = match (children.required, children.multiple) {
183        (_, true) => FieldType::Repeated(type_ref),
184        (false, false) => FieldType::Optional(type_ref),
185        (true, false) => FieldType::Direct(type_ref),
186    };
187    ChildrenDef { field_type }
188}
189
190fn map_type_reference(
191    types: &[TypeRef],
192    parent_name: &str,
193    field_name: &str,
194    supertype_kinds: &std::collections::HashSet<&str>,
195) -> TypeReference {
196    // Filter to only named types for the type reference.
197    // Anonymous nodes (punctuation) in field types are unusual but can appear in alternations.
198    if types.len() == 1 && types[0].named {
199        TypeReference::Named(name_mangler::type_ident(&types[0].type_name))
200    } else {
201        // Multiple types or contains anonymous → alternation enum
202        let enum_name = format!(
203            "{}{}",
204            parent_name,
205            name_mangler::to_pascal_case(field_name)
206        );
207        TypeReference::Alternation(AlternationEnumDef {
208            type_name: quote::format_ident!("{}", enum_name),
209            variants: types
210                .iter()
211                .map(|tr| make_variant_def(tr, supertype_kinds))
212                .collect(),
213        })
214    }
215}
216
217/// Supertype nodes start with `_` (e.g., `_expression`). Strip the prefix for the type name.
218fn supertype_ident(kind: &str) -> Ident {
219    let stripped = kind.strip_prefix('_').unwrap_or(kind);
220    name_mangler::type_ident(stripped)
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::grammar_ir::parse_node_types;

    /// Parses `json` as node-types data and runs the result through `map_types`.
    fn decisions_for(json: &str) -> Vec<TypeDecision> {
        let nodes = parse_node_types(json).unwrap();
        map_types(&nodes)
    }

    /// Maps `json` and returns its first decision, which must be a `Struct`.
    fn first_struct(json: &str) -> StructDef {
        let decisions = decisions_for(json);
        match &decisions[0] {
            TypeDecision::Struct(def) => def.clone(),
            _ => panic!("expected Struct"),
        }
    }

    // A named node with nothing else declared is a leaf.
    #[test]
    fn test_leaf_node_maps_to_leaf_struct() {
        let decisions = decisions_for(r#"[{"type": "identifier", "named": true}]"#);
        assert_eq!(decisions.len(), 1);
        assert!(
            matches!(&decisions[0], TypeDecision::LeafStruct(def) if def.type_name == "Identifier")
        );
    }

    // Anonymous nodes never produce a top-level type.
    #[test]
    fn test_unnamed_nodes_are_skipped() {
        assert!(decisions_for(r#"[{"type": ".", "named": false}]"#).is_empty());
    }

    // required + single → `Direct`, and names are mangled into Rust identifiers.
    #[test]
    fn test_node_with_fields_maps_to_struct() {
        let json = r#"[{
            "type": "import_spec",
            "named": true,
            "fields": {
                "path": {
                    "multiple": false,
                    "required": true,
                    "types": [{"type": "interpreted_string_literal", "named": true}]
                }
            }
        }]"#;
        let decisions = decisions_for(json);
        assert_eq!(decisions.len(), 1);
        let TypeDecision::Struct(def) = &decisions[0] else {
            panic!("expected Struct");
        };
        assert_eq!(def.type_name.to_string(), "ImportSpec");
        assert_eq!(def.fields.len(), 1);
        let path = &def.fields[0];
        assert_eq!(path.field_name.to_string(), "path");
        assert!(matches!(&path.field_type, FieldType::Direct(_)));
    }

    // not required + single → `Optional`.
    #[test]
    fn test_optional_field() {
        let json = r#"[{
            "type": "import_spec",
            "named": true,
            "fields": {
                "name": {
                    "multiple": false,
                    "required": false,
                    "types": [{"type": "identifier", "named": true}]
                }
            }
        }]"#;
        let def = first_struct(json);
        assert!(matches!(&def.fields[0].field_type, FieldType::Optional(_)));
    }

    // multiple → `Repeated`, even when not required.
    #[test]
    fn test_repeated_field() {
        let json = r#"[{
            "type": "block",
            "named": true,
            "fields": {
                "statements": {
                    "multiple": true,
                    "required": false,
                    "types": [{"type": "statement", "named": true}]
                }
            }
        }]"#;
        let def = first_struct(json);
        assert!(matches!(&def.fields[0].field_type, FieldType::Repeated(_)));
    }

    // Several admissible types (one of them anonymous) → alternation enum named after
    // parent + field, with one variant per type in grammar order.
    #[test]
    fn test_alternation_field() {
        let json = r#"[{
            "type": "import_spec",
            "named": true,
            "fields": {
                "name": {
                    "multiple": false,
                    "required": false,
                    "types": [
                        {"type": ".", "named": false},
                        {"type": "identifier", "named": true}
                    ]
                }
            }
        }]"#;
        let def = first_struct(json);
        let FieldType::Optional(TypeReference::Alternation(alt)) = &def.fields[0].field_type else {
            panic!("expected Optional(Alternation)");
        };
        assert_eq!(alt.type_name.to_string(), "ImportSpecName");
        let variant_names: Vec<String> = alt
            .variants
            .iter()
            .map(|v| v.variant_name.to_string())
            .collect();
        assert_eq!(variant_names, ["Dot", "Identifier"]);
    }

    // A node listing `subtypes` becomes an enum whose name drops the leading underscore.
    #[test]
    fn test_supertype_maps_to_enum() {
        let json = r#"[{
            "type": "_expression",
            "named": true,
            "subtypes": [
                {"type": "binary_expression", "named": true},
                {"type": "identifier", "named": true}
            ]
        }]"#;
        let decisions = decisions_for(json);
        let TypeDecision::SupertypeEnum(def) = &decisions[0] else {
            panic!("expected SupertypeEnum");
        };
        assert_eq!(def.type_name.to_string(), "Expression");
        assert_eq!(def.variants.len(), 2);
    }

    // A `children` entry alone is enough to make the node a struct.
    #[test]
    fn test_node_with_children() {
        let json = r#"[{
            "type": "import_spec_list",
            "named": true,
            "fields": {},
            "children": {
                "multiple": true,
                "required": false,
                "types": [{"type": "import_spec", "named": true}]
            }
        }]"#;
        let def = first_struct(json);
        assert!(matches!(
            &def.children,
            Some(ChildrenDef {
                field_type: FieldType::Repeated(_)
            })
        ));
    }
}