Skip to main content

treesitter_types/codegen/
type_mapper.rs

1use super::grammar_ir::{FieldInfo, NodeType, TypeRef};
2use super::name_mangler;
3use proc_macro2::Ident;
4
/// What Rust type to emit for a top-level node.
///
/// `map_node` picks the variant for each named node: a node with `subtypes` becomes
/// `SupertypeEnum`, a node with fields or a `children` entry becomes `Struct`, and any other
/// named node becomes `LeafStruct`.
#[derive(Debug, Clone)]
pub enum TypeDecision {
    /// A named node with fields and/or children → emit a struct.
    Struct(StructDef),
    /// A named terminal node (no fields, no children, no subtypes) → emit a leaf struct.
    LeafStruct(LeafStructDef),
    /// A supertype node (has `subtypes`) → emit an enum.
    SupertypeEnum(SupertypeEnumDef),
}
15
/// Description of a struct to emit for a named node that has fields and/or children.
#[derive(Debug, Clone)]
pub struct StructDef {
    /// Identifier of the generated Rust type (built by `name_mangler::type_ident`).
    pub type_name: Ident,
    /// The node's original kind string, copied from `NodeType::type_name`.
    pub kind: String,
    /// One entry per field declared on the node.
    pub fields: Vec<FieldDef>,
    /// Mapping for the node's `children` entry; `None` when the node has none.
    pub children: Option<ChildrenDef>,
}
23
/// Description of a leaf struct: a named node with no fields, children, or subtypes.
///
/// `map_types` also synthesizes one of these for every named kind that is referenced by
/// another node but has no entry of its own.
#[derive(Debug, Clone)]
pub struct LeafStructDef {
    /// Identifier of the generated Rust type (built by `name_mangler::type_ident`).
    pub type_name: Ident,
    /// The original kind string this struct stands for.
    pub kind: String,
}
29
/// Description of an enum to emit for a supertype node (a node that lists `subtypes`).
#[derive(Debug, Clone)]
pub struct SupertypeEnumDef {
    /// Identifier of the generated enum (built by `supertype_ident`).
    pub type_name: Ident,
    /// The supertype's original kind string, copied from `NodeType::type_name`.
    pub kind: String,
    /// One variant per entry in the node's `subtypes` list, in the same order.
    pub variants: Vec<VariantDef>,
}
36
/// Description of a single named field on a generated struct.
#[derive(Debug, Clone)]
pub struct FieldDef {
    /// Identifier of the Rust field (built by `name_mangler::field_ident`).
    pub field_name: Ident,
    /// The original field name string (for `child_by_field_name`).
    pub raw_field_name: String,
    /// The field's type, including whether it is direct, optional, or repeated.
    pub field_type: FieldType,
}
44
/// Description of a node's `children` entry on a generated struct.
#[derive(Debug, Clone)]
pub struct ChildrenDef {
    /// The children's type, including whether it is direct, optional, or repeated.
    pub field_type: FieldType,
}
49
/// How to represent a field's type in Rust.
///
/// The variant is derived from the grammar's `required` / `multiple` flags: `multiple` always
/// yields `Repeated`; otherwise `required` yields `Direct` and non-required yields `Optional`.
#[derive(Debug, Clone)]
pub enum FieldType {
    /// A single concrete type (e.g., `Identifier<'tree>`).
    Direct(TypeReference),
    /// Wrapped in `Option<T>`.
    Optional(TypeReference),
    /// Wrapped in `Vec<T>`.
    Repeated(TypeReference),
}
60
/// Reference to a Rust type, which may be a single named type or an alternation enum.
///
/// `map_type_reference` produces `Named` only when the grammar lists exactly one type and that
/// type is named; every other case produces `Alternation`.
#[derive(Debug, Clone)]
pub enum TypeReference {
    /// A single named node type.
    Named(Ident),
    /// An alternation enum generated for this field.
    Alternation(AlternationEnumDef),
}
69
/// Description of an enum generated for a field (or `children` entry) that can hold several
/// node types.
#[derive(Debug, Clone)]
pub struct AlternationEnumDef {
    /// Identifier of the generated enum: the parent type's name followed by the PascalCase
    /// field name (e.g. `ImportSpec` + `name` → `ImportSpecName`).
    pub type_name: Ident,
    /// One variant per type listed for the field, in the same order.
    pub variants: Vec<VariantDef>,
}
75
/// Description of one variant of a supertype enum or an alternation enum.
#[derive(Debug, Clone)]
pub struct VariantDef {
    /// Identifier of the enum variant (built by `name_mangler::variant_name`).
    pub variant_name: Ident,
    /// The original kind string of the referenced type.
    pub kind: String,
    /// Whether the referenced type is a named node (copied from the grammar's `named` flag).
    pub named: bool,
    /// True if this variant references a supertype (has subtypes in the grammar).
    /// Supertypes are abstract — tree-sitter never emits them as node kinds at runtime.
    pub is_supertype: bool,
    /// Additional kind strings that map to the same variant name (e.g., "a" and "A" both → `A`).
    /// Populated during deduplication in the emitter.
    pub extra_kinds: Vec<String>,
}
88
89/// Maps all `NodeType` entries into `TypeDecision`s.
90pub fn map_types(nodes: &[NodeType]) -> Vec<TypeDecision> {
91    // Collect the set of type names that are supertypes
92    let supertype_kinds: std::collections::HashSet<&str> = nodes
93        .iter()
94        .filter(|n| n.subtypes.is_some())
95        .map(|n| n.type_name.as_str())
96        .collect();
97
98    // Collect concrete (non-supertype) node kinds for conflict detection
99    let concrete_kinds: std::collections::HashSet<&str> = nodes
100        .iter()
101        .filter(|n| n.named && n.subtypes.is_none())
102        .map(|n| n.type_name.as_str())
103        .collect();
104
105    let mut decisions: Vec<TypeDecision> = nodes
106        .iter()
107        .filter(|n| n.named)
108        .map(|n| map_node(n, &supertype_kinds, &concrete_kinds))
109        .collect();
110
111    // Collect all defined type names
112    let defined_kinds: std::collections::HashSet<String> = nodes
113        .iter()
114        .filter(|n| n.named)
115        .map(|n| n.type_name.clone())
116        .collect();
117
118    // Collect all referenced named types from fields, children, and subtypes
119    let mut referenced_kinds = std::collections::HashSet::new();
120    for node in nodes.iter().filter(|n| n.named) {
121        for field_info in node.fields.values() {
122            for tr in &field_info.types {
123                if tr.named {
124                    referenced_kinds.insert(tr.type_name.clone());
125                }
126            }
127        }
128        if let Some(children) = &node.children {
129            for tr in &children.types {
130                if tr.named {
131                    referenced_kinds.insert(tr.type_name.clone());
132                }
133            }
134        }
135        if let Some(subtypes) = &node.subtypes {
136            for tr in subtypes {
137                if tr.named {
138                    referenced_kinds.insert(tr.type_name.clone());
139                }
140            }
141        }
142    }
143
144    // Generate leaf structs for referenced but undefined types
145    for kind in &referenced_kinds {
146        if !defined_kinds.contains(kind) {
147            decisions.push(TypeDecision::LeafStruct(LeafStructDef {
148                type_name: name_mangler::type_ident(kind),
149                kind: kind.clone(),
150            }));
151        }
152    }
153
154    decisions
155}
156
157fn map_node(
158    node: &NodeType,
159    supertype_kinds: &std::collections::HashSet<&str>,
160    concrete_kinds: &std::collections::HashSet<&str>,
161) -> TypeDecision {
162    let raw_kind = &node.type_name;
163
164    // Supertype nodes (e.g., _expression, statement) → enum
165    if let Some(subtypes) = &node.subtypes {
166        return TypeDecision::SupertypeEnum(SupertypeEnumDef {
167            type_name: supertype_ident(raw_kind, concrete_kinds),
168            kind: raw_kind.clone(),
169            variants: subtypes
170                .iter()
171                .map(|tr| make_variant_def(tr, supertype_kinds))
172                .collect(),
173        });
174    }
175
176    // Named node with fields or children → struct
177    if !node.fields.is_empty() || node.children.is_some() {
178        let type_name = name_mangler::type_ident(raw_kind);
179        let fields = node
180            .fields
181            .iter()
182            .map(|(field_name, field_info)| {
183                let parent_name = type_name.to_string();
184                map_field(
185                    field_name,
186                    field_info,
187                    &parent_name,
188                    supertype_kinds,
189                    concrete_kinds,
190                )
191            })
192            .collect();
193        let children = node.children.as_ref().map(|c| {
194            let parent_name = type_name.to_string();
195            map_children(c, &parent_name, supertype_kinds, concrete_kinds)
196        });
197        return TypeDecision::Struct(StructDef {
198            type_name,
199            kind: raw_kind.clone(),
200            fields,
201            children,
202        });
203    }
204
205    // Leaf node (named, no fields, no children, no subtypes)
206    TypeDecision::LeafStruct(LeafStructDef {
207        type_name: name_mangler::type_ident(raw_kind),
208        kind: raw_kind.clone(),
209    })
210}
211
212fn make_variant_def(tr: &TypeRef, supertype_kinds: &std::collections::HashSet<&str>) -> VariantDef {
213    VariantDef {
214        variant_name: name_mangler::variant_name(&tr.type_name, tr.named),
215        kind: tr.type_name.clone(),
216        named: tr.named,
217        is_supertype: tr.named && supertype_kinds.contains(tr.type_name.as_str()),
218        extra_kinds: Vec::new(),
219    }
220}
221
222fn map_field(
223    field_name: &str,
224    field_info: &FieldInfo,
225    parent_name: &str,
226    supertype_kinds: &std::collections::HashSet<&str>,
227    concrete_kinds: &std::collections::HashSet<&str>,
228) -> FieldDef {
229    let type_ref = map_type_reference(
230        &field_info.types,
231        parent_name,
232        field_name,
233        supertype_kinds,
234        concrete_kinds,
235    );
236    let field_type = match (field_info.required, field_info.multiple) {
237        (_, true) => FieldType::Repeated(type_ref),
238        (false, false) => FieldType::Optional(type_ref),
239        (true, false) => FieldType::Direct(type_ref),
240    };
241
242    FieldDef {
243        field_name: name_mangler::field_ident(field_name),
244        raw_field_name: field_name.to_owned(),
245        field_type,
246    }
247}
248
249fn map_children(
250    children: &FieldInfo,
251    parent_name: &str,
252    supertype_kinds: &std::collections::HashSet<&str>,
253    concrete_kinds: &std::collections::HashSet<&str>,
254) -> ChildrenDef {
255    let type_ref = map_type_reference(
256        &children.types,
257        parent_name,
258        "children",
259        supertype_kinds,
260        concrete_kinds,
261    );
262    let field_type = match (children.required, children.multiple) {
263        (_, true) => FieldType::Repeated(type_ref),
264        (false, false) => FieldType::Optional(type_ref),
265        (true, false) => FieldType::Direct(type_ref),
266    };
267    ChildrenDef { field_type }
268}
269
270fn map_type_reference(
271    types: &[TypeRef],
272    parent_name: &str,
273    field_name: &str,
274    supertype_kinds: &std::collections::HashSet<&str>,
275    concrete_kinds: &std::collections::HashSet<&str>,
276) -> TypeReference {
277    // Filter to only named types for the type reference.
278    // Anonymous nodes (punctuation) in field types are unusual but can appear in alternations.
279    if types.len() == 1 && types[0].named {
280        let ident = if supertype_kinds.contains(types[0].type_name.as_str()) {
281            supertype_ident(&types[0].type_name, concrete_kinds)
282        } else {
283            name_mangler::type_ident(&types[0].type_name)
284        };
285        TypeReference::Named(ident)
286    } else {
287        // Multiple types or contains anonymous → alternation enum
288        let enum_name = format!(
289            "{}{}",
290            parent_name,
291            name_mangler::to_pascal_case(field_name)
292        );
293        TypeReference::Alternation(AlternationEnumDef {
294            type_name: quote::format_ident!("{}", enum_name),
295            variants: types
296                .iter()
297                .map(|tr| make_variant_def(tr, supertype_kinds))
298                .collect(),
299        })
300    }
301}
302
303/// Supertype nodes start with `_` (e.g., `_expression`). Strip the prefix for the type name.
304/// If stripping would conflict with a concrete node name, keep a suffix to disambiguate.
305fn supertype_ident(kind: &str, concrete_kinds: &std::collections::HashSet<&str>) -> Ident {
306    let stripped = kind.strip_prefix('_').unwrap_or(kind);
307    // Check if the stripped name conflicts with a concrete (non-supertype) node
308    if concrete_kinds.contains(stripped) {
309        // Append "Type" to disambiguate (e.g., _parameter → ParameterType, parameter → Parameter)
310        let pascal = name_mangler::to_pascal_case(stripped);
311        quote::format_ident!("{}Type", pascal)
312    } else {
313        name_mangler::type_ident(stripped)
314    }
315}
316
/// Unit tests for `map_types`, driven by small inline `node-types.json` snippets.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::grammar_ir::parse_node_types;

    /// A named node with no fields, children, or subtypes becomes a `LeafStruct`.
    #[test]
    fn test_leaf_node_maps_to_leaf_struct() {
        let nodes = parse_node_types(r#"[{"type": "identifier", "named": true}]"#).unwrap();
        let decisions = map_types(&nodes);
        assert_eq!(decisions.len(), 1);
        assert!(
            matches!(&decisions[0], TypeDecision::LeafStruct(def) if def.type_name == "Identifier")
        );
    }

    /// Unnamed (anonymous) nodes produce no decision at all.
    #[test]
    fn test_unnamed_nodes_are_skipped() {
        let nodes = parse_node_types(r#"[{"type": ".", "named": false}]"#).unwrap();
        let decisions = map_types(&nodes);
        assert!(decisions.is_empty());
    }

    /// A node with a required, single-valued field becomes a `Struct` with a `Direct` field.
    #[test]
    fn test_node_with_fields_maps_to_struct() {
        let json = r#"[
            {"type": "interpreted_string_literal", "named": true},
            {
                "type": "import_spec",
                "named": true,
                "fields": {
                    "path": {
                        "multiple": false,
                        "required": true,
                        "types": [{"type": "interpreted_string_literal", "named": true}]
                    }
                }
            }
        ]"#;
        let nodes = parse_node_types(json).unwrap();
        let decisions = map_types(&nodes);
        // Both kinds are defined in the input, so no extra leaf struct is synthesized.
        assert_eq!(decisions.len(), 2);
        let TypeDecision::Struct(def) = &decisions[1] else {
            panic!("expected Struct");
        };
        assert_eq!(def.type_name.to_string(), "ImportSpec");
        assert_eq!(def.fields.len(), 1);
        assert_eq!(def.fields[0].field_name.to_string(), "path");
        assert!(matches!(&def.fields[0].field_type, FieldType::Direct(_)));
    }

    /// A non-required, single-valued field maps to `FieldType::Optional`.
    #[test]
    fn test_optional_field() {
        let json = r#"[{
            "type": "import_spec",
            "named": true,
            "fields": {
                "name": {
                    "multiple": false,
                    "required": false,
                    "types": [{"type": "identifier", "named": true}]
                }
            }
        }]"#;
        let nodes = parse_node_types(json).unwrap();
        let decisions = map_types(&nodes);
        // Decisions for input nodes come first, so index 0 is `import_spec`.
        let TypeDecision::Struct(def) = &decisions[0] else {
            panic!("expected Struct");
        };
        assert!(matches!(&def.fields[0].field_type, FieldType::Optional(_)));
    }

    /// A field with `multiple: true` maps to `FieldType::Repeated`, even when not required.
    #[test]
    fn test_repeated_field() {
        let json = r#"[{
            "type": "block",
            "named": true,
            "fields": {
                "statements": {
                    "multiple": true,
                    "required": false,
                    "types": [{"type": "statement", "named": true}]
                }
            }
        }]"#;
        let nodes = parse_node_types(json).unwrap();
        let decisions = map_types(&nodes);
        let TypeDecision::Struct(def) = &decisions[0] else {
            panic!("expected Struct");
        };
        assert!(matches!(&def.fields[0].field_type, FieldType::Repeated(_)));
    }

    /// A field listing several types (here one anonymous, one named) yields an alternation
    /// enum named after the parent type and the field, with variants in listing order.
    #[test]
    fn test_alternation_field() {
        let json = r#"[{
            "type": "import_spec",
            "named": true,
            "fields": {
                "name": {
                    "multiple": false,
                    "required": false,
                    "types": [
                        {"type": ".", "named": false},
                        {"type": "identifier", "named": true}
                    ]
                }
            }
        }]"#;
        let nodes = parse_node_types(json).unwrap();
        let decisions = map_types(&nodes);
        let TypeDecision::Struct(def) = &decisions[0] else {
            panic!("expected Struct");
        };
        let FieldType::Optional(TypeReference::Alternation(alt)) = &def.fields[0].field_type else {
            panic!("expected Optional(Alternation)");
        };
        assert_eq!(alt.type_name.to_string(), "ImportSpecName");
        assert_eq!(alt.variants.len(), 2);
        assert_eq!(alt.variants[0].variant_name.to_string(), "Dot");
        assert_eq!(alt.variants[1].variant_name.to_string(), "Identifier");
    }

    /// A node with `subtypes` becomes a `SupertypeEnum`; the leading `_` is dropped from its
    /// type name.
    #[test]
    fn test_supertype_maps_to_enum() {
        let json = r#"[{
            "type": "_expression",
            "named": true,
            "subtypes": [
                {"type": "binary_expression", "named": true},
                {"type": "identifier", "named": true}
            ]
        }]"#;
        let nodes = parse_node_types(json).unwrap();
        let decisions = map_types(&nodes);
        let TypeDecision::SupertypeEnum(def) = &decisions[0] else {
            panic!("expected SupertypeEnum");
        };
        assert_eq!(def.type_name.to_string(), "Expression");
        assert_eq!(def.variants.len(), 2);
    }

    /// A node with no fields but a `children` entry still becomes a `Struct`, and
    /// `multiple: true` children map to `FieldType::Repeated`.
    #[test]
    fn test_node_with_children() {
        let json = r#"[{
            "type": "import_spec_list",
            "named": true,
            "fields": {},
            "children": {
                "multiple": true,
                "required": false,
                "types": [{"type": "import_spec", "named": true}]
            }
        }]"#;
        let nodes = parse_node_types(json).unwrap();
        let decisions = map_types(&nodes);
        let TypeDecision::Struct(def) = &decisions[0] else {
            panic!("expected Struct");
        };
        assert!(def.children.is_some());
        assert!(matches!(
            &def.children.as_ref().unwrap().field_type,
            FieldType::Repeated(_)
        ));
    }
}