//! Type mapping for tree-sitter code generation (`treesitter_types/codegen/type_mapper.rs`).
//!
//! Converts parsed tree-sitter node-type entries ([`NodeType`]) into [`TypeDecision`]s that
//! describe which Rust types the code generator should emit.

use std::collections::HashSet;

use proc_macro2::Ident;

use super::grammar_ir::{FieldInfo, NodeType, TypeRef};
use super::name_mangler;

5#[derive(Debug, Clone)]
7pub enum TypeDecision {
8 Struct(StructDef),
10 LeafStruct(LeafStructDef),
12 SupertypeEnum(SupertypeEnumDef),
14}
15
16#[derive(Debug, Clone)]
17pub struct StructDef {
18 pub type_name: Ident,
19 pub kind: String,
20 pub fields: Vec<FieldDef>,
21 pub children: Option<ChildrenDef>,
22}
23
24#[derive(Debug, Clone)]
25pub struct LeafStructDef {
26 pub type_name: Ident,
27 pub kind: String,
28}
29
30#[derive(Debug, Clone)]
31pub struct SupertypeEnumDef {
32 pub type_name: Ident,
33 pub kind: String,
34 pub variants: Vec<VariantDef>,
35}
36
37#[derive(Debug, Clone)]
38pub struct FieldDef {
39 pub field_name: Ident,
40 pub raw_field_name: String,
42 pub field_type: FieldType,
43}
44
45#[derive(Debug, Clone)]
46pub struct ChildrenDef {
47 pub field_type: FieldType,
48}
49
50#[derive(Debug, Clone)]
52pub enum FieldType {
53 Direct(TypeReference),
55 Optional(TypeReference),
57 Repeated(TypeReference),
59}
60
61#[derive(Debug, Clone)]
63pub enum TypeReference {
64 Named(Ident),
66 Alternation(AlternationEnumDef),
68}
69
70#[derive(Debug, Clone)]
71pub struct AlternationEnumDef {
72 pub type_name: Ident,
73 pub variants: Vec<VariantDef>,
74}
75
76#[derive(Debug, Clone)]
77pub struct VariantDef {
78 pub variant_name: Ident,
79 pub kind: String,
80 pub named: bool,
81 pub is_supertype: bool,
84}
85
86pub fn map_types(nodes: &[NodeType]) -> Vec<TypeDecision> {
88 let supertype_kinds: std::collections::HashSet<&str> = nodes
90 .iter()
91 .filter(|n| n.subtypes.is_some())
92 .map(|n| n.type_name.as_str())
93 .collect();
94
95 nodes
96 .iter()
97 .filter(|n| n.named)
98 .map(|n| map_node(n, &supertype_kinds))
99 .collect()
100}
101
102fn map_node(node: &NodeType, supertype_kinds: &std::collections::HashSet<&str>) -> TypeDecision {
103 let raw_kind = &node.type_name;
104
105 if let Some(subtypes) = &node.subtypes {
107 return TypeDecision::SupertypeEnum(SupertypeEnumDef {
108 type_name: supertype_ident(raw_kind),
109 kind: raw_kind.clone(),
110 variants: subtypes
111 .iter()
112 .map(|tr| make_variant_def(tr, supertype_kinds))
113 .collect(),
114 });
115 }
116
117 if !node.fields.is_empty() || node.children.is_some() {
119 let type_name = name_mangler::type_ident(raw_kind);
120 let fields = node
121 .fields
122 .iter()
123 .map(|(field_name, field_info)| {
124 let parent_name = type_name.to_string();
125 map_field(field_name, field_info, &parent_name, supertype_kinds)
126 })
127 .collect();
128 let children = node.children.as_ref().map(|c| {
129 let parent_name = type_name.to_string();
130 map_children(c, &parent_name, supertype_kinds)
131 });
132 return TypeDecision::Struct(StructDef {
133 type_name,
134 kind: raw_kind.clone(),
135 fields,
136 children,
137 });
138 }
139
140 TypeDecision::LeafStruct(LeafStructDef {
142 type_name: name_mangler::type_ident(raw_kind),
143 kind: raw_kind.clone(),
144 })
145}
146
147fn make_variant_def(tr: &TypeRef, supertype_kinds: &std::collections::HashSet<&str>) -> VariantDef {
148 VariantDef {
149 variant_name: name_mangler::variant_name(&tr.type_name, tr.named),
150 kind: tr.type_name.clone(),
151 named: tr.named,
152 is_supertype: supertype_kinds.contains(tr.type_name.as_str()),
153 }
154}
155
156fn map_field(
157 field_name: &str,
158 field_info: &FieldInfo,
159 parent_name: &str,
160 supertype_kinds: &std::collections::HashSet<&str>,
161) -> FieldDef {
162 let type_ref = map_type_reference(&field_info.types, parent_name, field_name, supertype_kinds);
163 let field_type = match (field_info.required, field_info.multiple) {
164 (_, true) => FieldType::Repeated(type_ref),
165 (false, false) => FieldType::Optional(type_ref),
166 (true, false) => FieldType::Direct(type_ref),
167 };
168
169 FieldDef {
170 field_name: name_mangler::field_ident(field_name),
171 raw_field_name: field_name.to_owned(),
172 field_type,
173 }
174}
175
176fn map_children(
177 children: &FieldInfo,
178 parent_name: &str,
179 supertype_kinds: &std::collections::HashSet<&str>,
180) -> ChildrenDef {
181 let type_ref = map_type_reference(&children.types, parent_name, "children", supertype_kinds);
182 let field_type = match (children.required, children.multiple) {
183 (_, true) => FieldType::Repeated(type_ref),
184 (false, false) => FieldType::Optional(type_ref),
185 (true, false) => FieldType::Direct(type_ref),
186 };
187 ChildrenDef { field_type }
188}
189
190fn map_type_reference(
191 types: &[TypeRef],
192 parent_name: &str,
193 field_name: &str,
194 supertype_kinds: &std::collections::HashSet<&str>,
195) -> TypeReference {
196 if types.len() == 1 && types[0].named {
199 TypeReference::Named(name_mangler::type_ident(&types[0].type_name))
200 } else {
201 let enum_name = format!(
203 "{}{}",
204 parent_name,
205 name_mangler::to_pascal_case(field_name)
206 );
207 TypeReference::Alternation(AlternationEnumDef {
208 type_name: quote::format_ident!("{}", enum_name),
209 variants: types
210 .iter()
211 .map(|tr| make_variant_def(tr, supertype_kinds))
212 .collect(),
213 })
214 }
215}
216
217fn supertype_ident(kind: &str) -> Ident {
219 let stripped = kind.strip_prefix('_').unwrap_or(kind);
220 name_mangler::type_ident(stripped)
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::grammar_ir::parse_node_types;

    /// Parses a node-types JSON fixture and maps it.
    fn decisions_for(json: &str) -> Vec<TypeDecision> {
        let nodes = parse_node_types(json).expect("test fixture should be valid node-types JSON");
        map_types(&nodes)
    }

    /// Returns the first decision as a struct, panicking with the actual decision otherwise.
    fn expect_struct(decisions: &[TypeDecision]) -> &StructDef {
        match decisions.first() {
            Some(TypeDecision::Struct(def)) => def,
            other => panic!("expected Struct, got {:?}", other),
        }
    }

    #[test]
    fn test_leaf_node_maps_to_leaf_struct() {
        let decisions = decisions_for(r#"[{"type": "identifier", "named": true}]"#);
        assert_eq!(decisions.len(), 1);
        let TypeDecision::LeafStruct(def) = &decisions[0] else {
            panic!("expected LeafStruct");
        };
        assert_eq!(def.type_name.to_string(), "Identifier");
        assert_eq!(def.kind, "identifier");
    }

    #[test]
    fn test_unnamed_nodes_are_skipped() {
        let decisions = decisions_for(r#"[{"type": ".", "named": false}]"#);
        assert!(decisions.is_empty());
    }

    #[test]
    fn test_node_with_fields_maps_to_struct() {
        let json = r#"[{
            "type": "import_spec",
            "named": true,
            "fields": {
                "path": {
                    "multiple": false,
                    "required": true,
                    "types": [{"type": "interpreted_string_literal", "named": true}]
                }
            }
        }]"#;
        let decisions = decisions_for(json);
        assert_eq!(decisions.len(), 1);
        let def = expect_struct(&decisions);
        assert_eq!(def.type_name.to_string(), "ImportSpec");
        assert_eq!(def.kind, "import_spec");
        assert_eq!(def.fields.len(), 1);
        assert_eq!(def.fields[0].field_name.to_string(), "path");
        assert_eq!(def.fields[0].raw_field_name, "path");
        assert!(matches!(&def.fields[0].field_type, FieldType::Direct(_)));
        assert!(def.children.is_none());
    }

    #[test]
    fn test_optional_field() {
        let json = r#"[{
            "type": "import_spec",
            "named": true,
            "fields": {
                "name": {
                    "multiple": false,
                    "required": false,
                    "types": [{"type": "identifier", "named": true}]
                }
            }
        }]"#;
        let decisions = decisions_for(json);
        let def = expect_struct(&decisions);
        assert!(matches!(&def.fields[0].field_type, FieldType::Optional(_)));
    }

    #[test]
    fn test_repeated_field() {
        let json = r#"[{
            "type": "block",
            "named": true,
            "fields": {
                "statements": {
                    "multiple": true,
                    "required": false,
                    "types": [{"type": "statement", "named": true}]
                }
            }
        }]"#;
        let decisions = decisions_for(json);
        let def = expect_struct(&decisions);
        assert!(matches!(&def.fields[0].field_type, FieldType::Repeated(_)));
    }

    #[test]
    fn test_required_repeated_field_is_repeated() {
        // `multiple` wins over `required`.
        let json = r#"[{
            "type": "block",
            "named": true,
            "fields": {
                "statements": {
                    "multiple": true,
                    "required": true,
                    "types": [{"type": "statement", "named": true}]
                }
            }
        }]"#;
        let decisions = decisions_for(json);
        let def = expect_struct(&decisions);
        assert!(matches!(&def.fields[0].field_type, FieldType::Repeated(_)));
    }

    #[test]
    fn test_alternation_field() {
        let json = r#"[{
            "type": "import_spec",
            "named": true,
            "fields": {
                "name": {
                    "multiple": false,
                    "required": false,
                    "types": [
                        {"type": ".", "named": false},
                        {"type": "identifier", "named": true}
                    ]
                }
            }
        }]"#;
        let decisions = decisions_for(json);
        let def = expect_struct(&decisions);
        let FieldType::Optional(TypeReference::Alternation(alt)) = &def.fields[0].field_type else {
            panic!("expected Optional(Alternation)");
        };
        assert_eq!(alt.type_name.to_string(), "ImportSpecName");
        assert_eq!(alt.variants.len(), 2);
        assert_eq!(alt.variants[0].variant_name.to_string(), "Dot");
        assert!(!alt.variants[0].named);
        assert_eq!(alt.variants[1].variant_name.to_string(), "Identifier");
        assert!(alt.variants[1].named);
    }

    #[test]
    fn test_single_anonymous_type_uses_alternation() {
        // Only a single *named* type is referenced directly; a lone token still gets an enum.
        let json = r#"[{
            "type": "assignment",
            "named": true,
            "fields": {
                "operator": {
                    "multiple": false,
                    "required": true,
                    "types": [{"type": "=", "named": false}]
                }
            }
        }]"#;
        let decisions = decisions_for(json);
        let def = expect_struct(&decisions);
        let FieldType::Direct(TypeReference::Alternation(alt)) = &def.fields[0].field_type else {
            panic!("expected Direct(Alternation)");
        };
        assert_eq!(alt.variants.len(), 1);
        assert_eq!(alt.variants[0].kind, "=");
        assert!(!alt.variants[0].named);
    }

    #[test]
    fn test_supertype_maps_to_enum() {
        let json = r#"[{
            "type": "_expression",
            "named": true,
            "subtypes": [
                {"type": "binary_expression", "named": true},
                {"type": "identifier", "named": true}
            ]
        }]"#;
        let decisions = decisions_for(json);
        let TypeDecision::SupertypeEnum(def) = &decisions[0] else {
            panic!("expected SupertypeEnum");
        };
        assert_eq!(def.type_name.to_string(), "Expression");
        assert_eq!(def.kind, "_expression");
        assert_eq!(def.variants.len(), 2);
        assert_eq!(def.variants[0].kind, "binary_expression");
        assert_eq!(def.variants[1].kind, "identifier");
        assert!(def.variants.iter().all(|v| !v.is_supertype));
    }

    #[test]
    fn test_supertype_variant_flags_nested_supertype() {
        let json = r#"[
            {
                "type": "_expression",
                "named": true,
                "subtypes": [
                    {"type": "_primary_expression", "named": true},
                    {"type": "binary_expression", "named": true}
                ]
            },
            {
                "type": "_primary_expression",
                "named": true,
                "subtypes": [{"type": "identifier", "named": true}]
            }
        ]"#;
        let decisions = decisions_for(json);
        assert_eq!(decisions.len(), 2);
        let TypeDecision::SupertypeEnum(def) = &decisions[0] else {
            panic!("expected SupertypeEnum");
        };
        assert_eq!(def.variants[0].kind, "_primary_expression");
        assert!(def.variants[0].is_supertype);
        assert!(!def.variants[1].is_supertype);
    }

    #[test]
    fn test_node_with_children() {
        let json = r#"[{
            "type": "import_spec_list",
            "named": true,
            "fields": {},
            "children": {
                "multiple": true,
                "required": false,
                "types": [{"type": "import_spec", "named": true}]
            }
        }]"#;
        let decisions = decisions_for(json);
        let def = expect_struct(&decisions);
        assert!(def.fields.is_empty());
        let Some(children) = &def.children else {
            panic!("expected children");
        };
        assert!(matches!(&children.field_type, FieldType::Repeated(_)));
    }
}