tree_sitter_query_formatter/
lib.rs

1use pretty::RcDoc;
2use tree_sitter::{Node, Parser};
3
4fn map_named_node_without_captures<'a>(node: Node<'a>, source: &'a str) -> RcDoc<'a, ()> {
5    let mut docs = Vec::new();
6
7    let mut field_docs = Vec::new();
8    let mut has_fields = false;
9
10    for i in 0..node.child_count() {
11        if let Some(child) = node.child(i) {
12            match child.kind() {
13                "(" => docs.push(RcDoc::text("(")),
14                ")" => {}
15                "field_definition" => {
16                    field_docs.push(map(child, source));
17                    has_fields = true;
18                }
19                "negated_field" => {
20                    field_docs.push(map(child, source));
21                    has_fields = true;
22                }
23                "identifier" => docs.push(map(child, source)),
24                "_" => docs.push(map(child, source)),
25                "named_node" => {
26                    // When a named_node (which has no fields) contains nested named_nodes,
27                    // try to fit on one line, but break if too long
28                    let nested_content = RcDoc::group(RcDoc::concat(vec![RcDoc::nest(
29                        RcDoc::concat(vec![RcDoc::line(), map(child, source)]),
30                        2,
31                    )]));
32                    docs.push(nested_content);
33                }
34                "capture" => {}
35                _ => docs.push(map(child, source)),
36            }
37        }
38    }
39
40    if has_fields {
41        let nested_fields = RcDoc::nest(
42            RcDoc::concat(vec![
43                RcDoc::hardline(),
44                RcDoc::intersperse(field_docs, RcDoc::hardline()),
45            ]),
46            2,
47        );
48        docs.push(nested_fields);
49        docs.push(RcDoc::text(")"));
50        RcDoc::concat(docs)
51    } else {
52        docs.push(RcDoc::text(")"));
53        RcDoc::concat(docs)
54    }
55}
56
57fn map<'a>(node: Node<'a>, source: &'a str) -> RcDoc<'a, ()> {
58    match node.kind() {
59        "program" => {
60            let mut docs = Vec::new();
61            for i in 0..node.child_count() {
62                if let Some(child) = node.child(i) {
63                    docs.push(map(child, source));
64                }
65            }
66            RcDoc::intersperse(docs, RcDoc::line())
67        }
68        "named_node" => {
69            let mut docs = Vec::new();
70
71            let mut field_docs = Vec::new();
72            let mut predicate_docs = Vec::new();
73            let mut has_fields = false;
74            let mut has_capture_after_paren = false;
75
76            for i in 0..node.child_count() {
77                if let Some(child) = node.child(i) {
78                    if child.kind() == ")" {
79                        if i + 1 < node.child_count() {
80                            has_capture_after_paren = true;
81                        }
82                    } else if child.kind() == "field_definition" || child.kind() == "negated_field"
83                    {
84                        has_fields = true;
85                    }
86                }
87            }
88
89            for i in 0..node.child_count() {
90                if let Some(child) = node.child(i) {
91                    match child.kind() {
92                        "(" => docs.push(RcDoc::text("(")),
93                        ")" => {
94                            if has_capture_after_paren {
95                                docs.push(RcDoc::text(")"));
96                            } else {
97                            }
98                        }
99                        "field_definition" => {
100                            field_docs.push(map(child, source));
101                        }
102                        "negated_field" => {
103                            field_docs.push(map(child, source));
104                        }
105                        "identifier" => docs.push(map(child, source)),
106                        "_" => docs.push(map(child, source)),
107                        "named_node" => {
108                            docs.push(RcDoc::text(" "));
109                            docs.push(map(child, source));
110                        }
111                        "capture" => docs.push(map(child, source)),
112                        "quantifier" => docs.push(map(child, source)),
113                        "predicate" => {
114                            if has_fields {
115                                predicate_docs.push(map(child, source));
116                            } else {
117                                docs.push(RcDoc::text(" "));
118                                docs.push(map(child, source));
119                            }
120                        }
121                        _ => docs.push(map(child, source)),
122                    }
123                }
124            }
125
126            if has_fields {
127                let nested_fields = RcDoc::nest(
128                    RcDoc::concat(vec![
129                        RcDoc::hardline(),
130                        RcDoc::intersperse(field_docs, RcDoc::hardline()),
131                        if !predicate_docs.is_empty() {
132                            RcDoc::concat(vec![
133                                RcDoc::hardline(),
134                                RcDoc::intersperse(predicate_docs, RcDoc::hardline()),
135                            ])
136                        } else {
137                            RcDoc::nil()
138                        },
139                    ]),
140                    2,
141                );
142                docs.push(nested_fields);
143                docs.push(RcDoc::text(")"));
144                RcDoc::concat(docs)
145            } else {
146                for predicate_doc in predicate_docs {
147                    docs.push(RcDoc::text(" "));
148                    docs.push(predicate_doc);
149                }
150                if !has_capture_after_paren {
151                    docs.push(RcDoc::text(")"));
152                }
153                RcDoc::concat(docs)
154            }
155        }
156        "field_definition" => {
157            let mut docs = Vec::new();
158
159            if let Some(name_child) = node.child_by_field_name("name") {
160                docs.push(map(name_child, source));
161                docs.push(RcDoc::text(": "));
162            }
163
164            for i in 0..node.child_count() {
165                if let Some(child) = node.child(i) {
166                    if child.kind() == "named_node" {
167                        docs.push(map_named_node_without_captures(child, source));
168                        for j in 0..child.child_count() {
169                            if let Some(capture_child) = child.child(j) {
170                                if capture_child.kind() == "capture" {
171                                    docs.push(map(capture_child, source));
172                                }
173                            }
174                        }
175                    } else if child.kind() != "identifier" && child.kind() != ":" {
176                        docs.push(map(child, source));
177                    }
178                }
179            }
180
181            RcDoc::concat(docs)
182        }
183        "identifier" => {
184            let text = &source[node.start_byte()..node.end_byte()];
185            RcDoc::text(text)
186        }
187        "capture" => {
188            let text = &source[node.start_byte()..node.end_byte()];
189            RcDoc::text(format!(" {}", text))
190        }
191        "anonymous_node" => {
192            let text = &source[node.start_byte()..node.end_byte()];
193            RcDoc::text(text)
194        }
195        "missing_node" => {
196            let mut docs = Vec::new();
197
198            for i in 0..node.child_count() {
199                if let Some(child) = node.child(i) {
200                    match child.kind() {
201                        "(" => docs.push(RcDoc::text("(")),
202                        "MISSING" => {
203                            docs.push(RcDoc::text("MISSING"));
204                            if i + 1 < node.child_count() {
205                                if let Some(next_child) = node.child(i + 1) {
206                                    if next_child.kind() != ")" && next_child.kind() != "capture" {
207                                        docs.push(RcDoc::space());
208                                    }
209                                }
210                            }
211                        }
212                        ")" => docs.push(RcDoc::text(")")),
213                        "capture" => docs.push(map(child, source)),
214                        _ => docs.push(map(child, source)),
215                    }
216                }
217            }
218
219            RcDoc::concat(docs)
220        }
221        "quantifier" => {
222            let text = &source[node.start_byte()..node.end_byte()];
223            RcDoc::text(text)
224        }
225        "grouping" => {
226            let mut docs = Vec::new();
227            let mut child_docs = Vec::new();
228            let mut quantifier_docs = Vec::new();
229            let mut capture_docs = Vec::new();
230
231            for i in 0..node.child_count() {
232                if let Some(child) = node.child(i) {
233                    match child.kind() {
234                        "(" => docs.push(RcDoc::text("(")),
235                        ")" => {}
236                        "named_node" => {
237                            child_docs.push(map(child, source));
238                        }
239                        "anonymous_node" => {
240                            child_docs.push(map(child, source));
241                        }
242                        "predicate" => {
243                            child_docs.push(map(child, source));
244                        }
245                        "." => {
246                            child_docs.push(RcDoc::text("."));
247                        }
248                        "quantifier" => {
249                            quantifier_docs.push(map(child, source));
250                        }
251                        "capture" => {
252                            capture_docs.push(map(child, source));
253                        }
254                        _ => docs.push(map(child, source)),
255                    }
256                }
257            }
258
259            if child_docs.len() > 1 {
260                let content = RcDoc::group(RcDoc::nest(
261                    RcDoc::concat(vec![
262                        RcDoc::line_(),
263                        RcDoc::intersperse(child_docs, RcDoc::line()),
264                        RcDoc::line_(),
265                    ]),
266                    2,
267                ));
268                docs.push(content);
269            } else if child_docs.len() == 1 {
270                docs.push(child_docs.into_iter().next().unwrap());
271            }
272
273            docs.push(RcDoc::text(")"));
274
275            for quantifier_doc in quantifier_docs {
276                docs.push(quantifier_doc);
277            }
278
279            for capture_doc in capture_docs {
280                docs.push(capture_doc);
281            }
282
283            RcDoc::concat(docs)
284        }
285        "list" => {
286            let mut docs = Vec::new();
287            let mut child_docs = Vec::new();
288            let mut captures = Vec::new();
289
290            for i in 0..node.child_count() {
291                if let Some(child) = node.child(i) {
292                    match child.kind() {
293                        "[" => docs.push(RcDoc::text("[")),
294                        "]" => {}
295                        "capture" => captures.push(map(child, source)),
296                        "anonymous_node" => {
297                            child_docs.push(map(child, source));
298                        }
299                        _ => {
300                            child_docs.push(map(child, source));
301                        }
302                    }
303                }
304            }
305
306            if !child_docs.is_empty() {
307                let content = RcDoc::nest(
308                    RcDoc::concat(vec![
309                        RcDoc::hardline(),
310                        RcDoc::intersperse(child_docs, RcDoc::hardline()),
311                        RcDoc::hardline(),
312                    ]),
313                    2,
314                );
315                docs.push(content);
316            }
317
318            docs.push(RcDoc::text("]"));
319
320            for capture in captures {
321                docs.push(capture);
322            }
323
324            RcDoc::concat(docs)
325        }
326        "_" => {
327            let text = &source[node.start_byte()..node.end_byte()];
328            RcDoc::text(text)
329        }
330        "predicate" => {
331            let mut docs = Vec::new();
332
333            for i in 0..node.child_count() {
334                if let Some(child) = node.child(i) {
335                    docs.push(map(child, source));
336                }
337            }
338
339            RcDoc::concat(docs)
340        }
341        "predicate_type" => {
342            let text = &source[node.start_byte()..node.end_byte()];
343            RcDoc::text(text)
344        }
345        "parameters" => {
346            let mut docs = Vec::new();
347            let mut string_docs = Vec::new();
348
349            for i in 0..node.child_count() {
350                if let Some(child) = node.child(i) {
351                    match child.kind() {
352                        "identifier" => {
353                            docs.push(RcDoc::space());
354                            docs.push(map(child, source))
355                        }
356                        "capture" => docs.push(map(child, source)),
357                        "string" => {
358                            string_docs.push(map(child, source));
359                        }
360                        _ => docs.push(map(child, source)),
361                    }
362                }
363            }
364
365            if string_docs.len() > 1 {
366                let formatted_strings = RcDoc::nest(
367                    RcDoc::concat(vec![
368                        RcDoc::hardline(),
369                        RcDoc::intersperse(string_docs, RcDoc::hardline()),
370                    ]),
371                    2,
372                );
373                docs.push(formatted_strings);
374            } else if string_docs.len() == 1 {
375                docs.push(RcDoc::space());
376                docs.push(string_docs.into_iter().next().unwrap());
377            }
378
379            RcDoc::concat(docs)
380        }
381        "string" => {
382            let text = &source[node.start_byte()..node.end_byte()];
383            RcDoc::text(text)
384        }
385        "string_content" => {
386            let text = &source[node.start_byte()..node.end_byte()];
387            RcDoc::text(text)
388        }
389        "negated_field" => {
390            let text = &source[node.start_byte()..node.end_byte()];
391            RcDoc::text(text)
392        }
393        "#" => {
394            let text = &source[node.start_byte()..node.end_byte()];
395            RcDoc::text(text)
396        }
397        "(" => RcDoc::text("("),
398        ")" => RcDoc::text(")"),
399        "." => RcDoc::text(" ."),
400        "/" => RcDoc::text("/"),
401        "comment" => {
402            let text = &source[node.start_byte()..node.end_byte()];
403            RcDoc::text(text)
404        }
405        _ => {
406            println!("Did not handle {}", node.kind());
407            RcDoc::nil()
408        }
409    }
410}
411
412/// Formats a Tree-sitter query string with proper indentation and line breaks.
413///
414/// Takes a Tree-sitter query as input and formats it according to the grammar rules,
415/// applying consistent indentation and line breaking to improve readability.
416///
417/// # Arguments
418///
419/// * `input` - The Tree-sitter query string to format
420/// * `width` - The target line width for formatting
421///
422/// # Returns
423///
424/// Returns a `Result` containing the formatted query string on success, or an error
425/// if parsing or formatting fails.
426///
427/// # Errors
428///
429/// This function will return an error if:
430/// - The Tree-sitter grammar cannot be loaded
431/// - The input query cannot be parsed
432/// - The formatted output cannot be rendered
433///
434/// # Example
435///
436/// ```
437/// use tree_sitter_query_formatter::format;
438///
439/// let query = "(function_definition name: (identifier) @func)";
440/// let formatted = format(query, 80).unwrap();
441/// ```
442pub fn format(input: &str, width: usize) -> Result<String, Box<dyn std::error::Error>> {
443    let mut parser = Parser::new();
444    parser
445        .set_language(&tree_sitter_tsquery::LANGUAGE.into())
446        .map_err(|e| format!("Error loading grammar: {:?}", e))?;
447
448    let tree = parser.parse(input, None).ok_or("Failed to parse input")?;
449
450    let root_node = tree.root_node();
451    let doc = map(root_node, input);
452
453    let mut w = Vec::new();
454    doc.render(width, &mut w)?;
455    let output = String::from_utf8(w)?;
456
457    Ok(output)
458}