topiary_core/
tree_sitter.rs

1// WASM build doesn't use topiary_tree_sitter_facade::QueryMatch or
2// streaming_iterator::StreamingIterator
3#![cfg_attr(target_arch = "wasm32", allow(unused_imports))]
4
5use std::{collections::HashSet, fmt::Display};
6
7use serde::Serialize;
8
9use topiary_tree_sitter_facade::{
10    Node, Parser, Point, Query, QueryCapture, QueryCursor, QueryMatch, QueryPredicate, Tree,
11};
12
13use streaming_iterator::StreamingIterator;
14
15use crate::{
16    atom_collection::{AtomCollection, QueryPredicates},
17    error::FormatterError,
18    FormatterResult,
19};
20
/// Supported visualisation formats
///
/// The enum is a plain, field-less tag: besides being `Copy`, it can be
/// compared for equality directly (e.g. `format == Visualisation::Json`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Visualisation {
    /// The GraphViz format.
    GraphViz,
    /// The JSON format.
    Json,
}
27
/// Refers to a position within the code. Used for error reporting, and for
/// comparing input with formatted output. The numbers are 1-based, because that
/// is how editors usually refer to a position. Derived from tree_sitter::Point.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize)]
pub struct Position {
    /// 1-based row (line) number.
    pub row: u32,
    /// 1-based column number.
    pub column: u32,
}
36
37impl Display for Position {
38    fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
39        write!(f, "({},{})", self.row, self.column)
40    }
41}
42
/// Topiary often needs both the tree-sitter `Query` and the original content
/// belonging to the file from which the query was parsed. This struct is a simple
/// convenience wrapper that combines the `Query` with its original string.
#[derive(Debug)]
pub struct TopiaryQuery {
    /// The tree-sitter query parsed from `query_content`.
    pub query: Query,
    /// The query source text that `query` was parsed from.
    pub query_content: String,
}
51
52impl TopiaryQuery {
53    /// Creates a new `TopiaryQuery` from a tree-sitter language/grammar and the
54    /// contents of the query file.
55    ///
56    /// # Errors
57    ///
58    /// This function will return an error if tree-sitter failed to parse the
59    /// query file.
60    pub fn new(
61        grammar: &topiary_tree_sitter_facade::Language,
62        query_content: &str,
63    ) -> FormatterResult<TopiaryQuery> {
64        let query = Query::new(grammar, query_content)
65            .map_err(|e| FormatterError::Query("Error parsing query file".into(), Some(e)))?;
66
67        Ok(TopiaryQuery {
68            query,
69            query_content: query_content.to_owned(),
70        })
71    }
72
73    /// Calculates the provided position of the Pattern in the query source file
74    /// from the byte offset of the pattern in the query.
75    #[cfg(not(target_arch = "wasm32"))]
76    pub fn pattern_position(&self, pattern_index: usize) -> Position {
77        let byte_offset = self.query.start_byte_for_pattern(pattern_index);
78        let (row, column) =
79            self.query_content[..byte_offset]
80                .chars()
81                .fold((0, 0), |(row, column), c| {
82                    if c == '\n' {
83                        (row + 1, 0)
84                    } else {
85                        (row, column + 1)
86                    }
87                });
88        Position {
89            row: row + 1,
90            column: column + 1,
91        }
92    }
93
94    #[cfg(target_arch = "wasm32")]
95    pub fn pattern_position(&self, _pattern_index: usize) -> Position {
96        unimplemented!()
97    }
98}
99
100impl From<Point> for Position {
101    fn from(point: Point) -> Self {
102        Self {
103            row: point.row() + 1,
104            column: point.column() + 1,
105        }
106    }
107}
108
/// Simplified syntactic node struct, for the sake of serialisation.
#[derive(Serialize)]
pub struct SyntaxNode {
    /// Identifier of the underlying tree-sitter node; left out of the
    /// serialised output.
    #[serde(skip_serializing)]
    pub id: usize,

    /// The node's kind, as reported by tree-sitter.
    pub kind: String,
    /// Value of the node's `is_named` flag.
    pub is_named: bool,
    /// Value of the node's `is_extra` flag.
    is_extra: bool,
    /// Value of the node's `is_error` flag.
    is_error: bool,
    /// Value of the node's `is_missing` flag.
    is_missing: bool,
    /// 1-based position at which the node starts.
    start: Position,
    /// 1-based position at which the node ends.
    end: Position,

    /// The node's children, converted recursively.
    pub children: Vec<SyntaxNode>,
}
125
126impl From<Node<'_>> for SyntaxNode {
127    fn from(node: Node) -> Self {
128        let mut walker = node.walk();
129        let children = node.children(&mut walker).map(Self::from).collect();
130
131        Self {
132            id: node.id(),
133
134            kind: node.kind().into(),
135            is_named: node.is_named(),
136            is_extra: node.is_extra(),
137            is_error: node.is_error(),
138            is_missing: node.is_missing(),
139            start: node.start_position().into(),
140            end: node.end_position().into(),
141
142            children,
143        }
144    }
145}
146
/// Extension trait for [`Node`] to allow for 1-based display in logs.
///
/// (Can't be done as a [`Display`] impl on [`Node`] directly, since that would
/// run into orphan issues. An alternative that would work is a [`Display`] impl
/// on a wrapper struct.)
pub trait NodeExt {
    /// Produce a textual representation with 1-based row/column indexes.
    ///
    /// The implementations in this file render the node's kind followed by
    /// its start and end positions.
    fn display_one_based(&self) -> String;
}
156
157impl NodeExt for Node<'_> {
158    fn display_one_based(&self) -> String {
159        format!(
160            "{{Node {:?} {} - {}}}",
161            self.kind(),
162            Position::from(self.start_position()),
163            Position::from(self.end_position()),
164        )
165    }
166}
167
168#[cfg(not(target_arch = "wasm32"))]
169impl NodeExt for tree_sitter::Node<'_> {
170    fn display_one_based(&self) -> String {
171        format!(
172            "{{Node {:?} {} - {}}}",
173            self.kind(),
174            Position::from(<tree_sitter::Point as Into<Point>>::into(
175                self.start_position()
176            )),
177            Position::from(<tree_sitter::Point as Into<Point>>::into(
178                self.end_position()
179            )),
180        )
181    }
182}
183
/// A struct to statically store the public fields of query match results,
/// to avoid running queries twice.
#[derive(Debug)]
struct LocalQueryMatch<'a> {
    /// Index of the query pattern that produced this match.
    pattern_index: usize,
    /// The captures of the match, collected from the query match.
    captures: Vec<QueryCapture<'a>>,
}
191
192impl Display for LocalQueryMatch<'_> {
193    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
194        write!(
195            f,
196            "LocalQueryMatch {{ pattern_index: {}, captures: [ ",
197            self.pattern_index
198        )?;
199        for (index, capture) in self.captures.iter().enumerate() {
200            if index > 0 {
201                write!(f, ", ")?;
202            }
203            // .node() doesn't provide access to the inner [`tree_sitter`]
204            // object. As a result, we can't get the index out directly, so we
205            // skip it for now.
206            write!(f, "{}", capture.node().display_one_based())?;
207        }
208        write!(f, " ] }}")?;
209        Ok(())
210    }
211}
212
/// A struct to store the result of a query coverage check
#[derive(Clone, Debug, PartialEq)]
pub struct CoverageData {
    /// Share of the query's patterns that matched something in the input.
    /// Despite the name, `check_query_coverage` stores a fraction between
    /// 0.0 and 1.0 here, not a value scaled to 100.
    pub cover_percentage: f32,
    /// Source text of each pattern that matched nothing in the input.
    pub missing_patterns: Vec<String>,
}
219
220/// Applies a query to an input content and returns a collection of atoms.
221///
222/// # Errors
223///
224/// This function can return an error if:
225/// - The input content cannot be parsed by the grammar.
226/// - The query content cannot be parsed by the grammar.
227/// - The input exhaustivity check fails.
228/// - A found predicate could not be parsed or is malformed.
229/// - A unknown capture name was encountered in the query.
230pub fn apply_query(
231    input_content: &str,
232    query: &TopiaryQuery,
233    grammar: &topiary_tree_sitter_facade::Language,
234    tolerate_parsing_errors: bool,
235) -> FormatterResult<AtomCollection> {
236    let tree = parse(input_content, grammar, tolerate_parsing_errors)?;
237    let root = tree.root_node();
238    let source = input_content.as_bytes();
239
240    // Match queries
241    let mut cursor = QueryCursor::new();
242    let mut matches: Vec<LocalQueryMatch> = Vec::new();
243    let capture_names = query.query.capture_names();
244
245    let mut query_matches = query.query.matches(&root, source, &mut cursor);
246    #[allow(clippy::while_let_on_iterator)] // This is not a normal iterator
247    while let Some(query_match) = query_matches.next() {
248        let local_captures: Vec<QueryCapture> = query_match.captures().collect();
249
250        matches.push(LocalQueryMatch {
251            pattern_index: query_match.pattern_index(),
252            captures: local_captures,
253        });
254    }
255
256    // Find the ids of all tree-sitter nodes that were identified as a leaf
257    // We want to avoid recursing into them in the collect_leaves function.
258    let specified_leaf_nodes: HashSet<usize> = collect_leaf_ids(&matches, capture_names.clone());
259
260    // The Flattening: collects all terminal nodes of the tree-sitter tree in a Vec
261    let mut atoms = AtomCollection::collect_leaves(&root, source, specified_leaf_nodes)?;
262
263    log::debug!("List of atoms before formatting: {atoms:?}");
264
265    // Memoization of the pattern positions
266    let mut pattern_positions: Vec<Option<Position>> = Vec::new();
267
268    // The web bindings for tree-sitter do not have support for pattern_count, so instead we will resize as needed
269    // Only reallocate if we are actually going to use the vec
270    #[cfg(not(target_arch = "wasm32"))]
271    if log::log_enabled!(log::Level::Info) {
272        pattern_positions.resize(query.query.pattern_count(), None);
273    }
274
275    // If there are more than one capture per match, it generally means that we
276    // want to use the last capture. For example
277    // (
278    //   (enum_item) @append_hardline .
279    //   (line_comment)? @append_hardline
280    // )
281    // means we want to append a hardline at
282    // the end, but we don't know if we get a line_comment capture or not.
283    for m in matches {
284        let mut predicates = QueryPredicates::default();
285
286        for p in query.query.general_predicates(m.pattern_index) {
287            predicates = handle_predicate(&p, &predicates)?;
288        }
289        check_predicates(&predicates)?;
290
291        // NOTE: Only performed if logging is enabled to avoid unnecessary computation of Position
292        if log::log_enabled!(log::Level::Info) {
293            #[cfg(target_arch = "wasm32")]
294            // Resize the pattern_positions vector if we need to store more positions
295            if m.pattern_index >= pattern_positions.len() {
296                pattern_positions.resize(m.pattern_index + 1, None);
297            }
298
299            // Fetch from pattern_positions, otherwise insert
300            let pos = pattern_positions[m.pattern_index].unwrap_or_else(|| {
301                let pos = query.pattern_position(m.pattern_index);
302                pattern_positions[m.pattern_index] = Some(pos);
303                pos
304            });
305
306            let query_name_info = if let Some(name) = &predicates.query_name {
307                format!(" of query \"{name}\"")
308            } else {
309                "".into()
310            };
311
312            log::info!("Processing match{query_name_info}: {m} at location {pos}");
313        }
314
315        // If any capture is a do_nothing, then do nothing.
316        if m.captures
317            .iter()
318            .any(|c| c.name(capture_names.as_slice()) == "do_nothing")
319        {
320            continue;
321        }
322
323        for c in m.captures {
324            let name = c.name(capture_names.as_slice());
325            atoms.resolve_capture(&name, &c.node(), &predicates)?;
326        }
327    }
328
329    // Now apply all atoms in prepend and append to the leaf nodes.
330    atoms.apply_prepends_and_appends();
331
332    Ok(atoms)
333}
334
335/// Parses some string into a syntax tree, given a tree-sitter grammar.
336pub fn parse(
337    content: &str,
338    grammar: &topiary_tree_sitter_facade::Language,
339    tolerate_parsing_errors: bool,
340) -> FormatterResult<Tree> {
341    let mut parser = Parser::new()?;
342    parser.set_language(grammar).map_err(|_| {
343        FormatterError::Internal("Could not apply Tree-sitter grammar".into(), None)
344    })?;
345
346    let tree = parser
347        .parse(content, None)?
348        .ok_or_else(|| FormatterError::Internal("Could not parse input".into(), None))?;
349
350    // Fail parsing if we don't get a complete syntax tree.
351    if !tolerate_parsing_errors {
352        check_for_error_nodes(&tree.root_node())?;
353    }
354
355    Ok(tree)
356}
357
358fn check_for_error_nodes(node: &Node) -> FormatterResult<()> {
359    if node.kind() == "ERROR" {
360        let start = node.start_position();
361        let end = node.end_position();
362
363        // Report 1-based lines and columns.
364        return Err(FormatterError::Parsing {
365            start_line: start.row() + 1,
366            start_column: start.column() + 1,
367            end_line: end.row() + 1,
368            end_column: end.column() + 1,
369        });
370    }
371
372    for child in node.children(&mut node.walk()) {
373        check_for_error_nodes(&child)?;
374    }
375
376    Ok(())
377}
378
379/// Collects the IDs of all leaf nodes in a set of query matches.
380///
381/// This function takes a slice of `LocalQueryMatch` and a slice of capture names,
382/// and returns a `HashSet` of node IDs that are matched by the "leaf" capture name.
383fn collect_leaf_ids(matches: &[LocalQueryMatch], capture_names: Vec<&str>) -> HashSet<usize> {
384    let mut ids = HashSet::new();
385
386    for m in matches {
387        for c in &m.captures {
388            if c.name(capture_names.as_slice()) == "leaf" {
389                ids.insert(c.node().id());
390            }
391        }
392    }
393    ids
394}
395
396/// Handles a query predicate and returns a new set of query predicates with the corresponding field updated.
397///
398/// # Arguments
399///
400/// * `predicate` - A reference to a `QueryPredicate` object that represents a predicate in a query pattern.
401/// * `predicates` - A reference to a `QueryPredicates` object that holds the current state of the query predicates.
402///
403/// # Returns
404///
405/// A `FormatterResult` that contains either a new `QueryPredicates` object with the updated field, or a `FormatterError` if the predicate is invalid or missing an argument.
406///
407/// # Errors
408///
409/// This function will return an error if:
410///
411/// * The predicate operator is not one of the supported ones.
412/// * The predicate operator requires an argument but none is provided.
413fn handle_predicate(
414    predicate: &QueryPredicate,
415    predicates: &QueryPredicates,
416) -> FormatterResult<QueryPredicates> {
417    let operator = &*predicate.operator();
418    if "delimiter!" == operator {
419        let arg =
420            predicate.args().into_iter().next().ok_or_else(|| {
421                FormatterError::Query(format!("{operator} needs an argument"), None)
422            })?;
423        Ok(QueryPredicates {
424            delimiter: Some(arg),
425            ..predicates.clone()
426        })
427    } else if "scope_id!" == operator {
428        let arg =
429            predicate.args().into_iter().next().ok_or_else(|| {
430                FormatterError::Query(format!("{operator} needs an argument"), None)
431            })?;
432        Ok(QueryPredicates {
433            scope_id: Some(arg),
434            ..predicates.clone()
435        })
436    } else if "single_line_only!" == operator {
437        Ok(QueryPredicates {
438            single_line_only: true,
439            ..predicates.clone()
440        })
441    } else if "multi_line_only!" == operator {
442        Ok(QueryPredicates {
443            multi_line_only: true,
444            ..predicates.clone()
445        })
446    } else if "single_line_scope_only!" == operator {
447        let arg =
448            predicate.args().into_iter().next().ok_or_else(|| {
449                FormatterError::Query(format!("{operator} needs an argument"), None)
450            })?;
451        Ok(QueryPredicates {
452            single_line_scope_only: Some(arg),
453            ..predicates.clone()
454        })
455    } else if "multi_line_scope_only!" == operator {
456        let arg =
457            predicate.args().into_iter().next().ok_or_else(|| {
458                FormatterError::Query(format!("{operator} needs an argument"), None)
459            })?;
460        Ok(QueryPredicates {
461            multi_line_scope_only: Some(arg),
462            ..predicates.clone()
463        })
464    } else if "query_name!" == operator {
465        let arg =
466            predicate.args().into_iter().next().ok_or_else(|| {
467                FormatterError::Query(format!("{operator} needs an argument"), None)
468            })?;
469        Ok(QueryPredicates {
470            query_name: Some(arg),
471            ..predicates.clone()
472        })
473    } else {
474        Err(FormatterError::Query(
475            format!("{operator} is an unknown predicate. Maybe you forgot a \"!\"?"),
476            None,
477        ))
478    }
479}
480
481/// Checks the validity of the query predicates.
482///
483/// This function ensures that the query predicates do not contain more than one
484/// of the following: #single_line_only, #multi_line_only, #single_line_scope_only,
485/// or #multi_line_scope_only. These predicates are incompatible with each other
486/// and would result in an invalid query.
487///
488/// # Arguments
489///
490/// * `predicates` - A reference to a QueryPredicates struct that holds the query predicates.
491///
492/// # Errors
493///
494/// If the query predicates contain more than one incompatible predicate, this function
495/// returns a FormatterError::Query with a descriptive message.
496fn check_predicates(predicates: &QueryPredicates) -> FormatterResult<()> {
497    let mut incompatible_predicates = 0;
498    if predicates.single_line_only {
499        incompatible_predicates += 1;
500    }
501    if predicates.multi_line_only {
502        incompatible_predicates += 1;
503    }
504    if predicates.single_line_scope_only.is_some() {
505        incompatible_predicates += 1;
506    }
507    if predicates.multi_line_scope_only.is_some() {
508        incompatible_predicates += 1;
509    }
510    if incompatible_predicates > 1 {
511        Err(FormatterError::Query(
512            "A query can contain at most one #single/multi_line[_scope]_only! predicate".into(),
513            None,
514        ))
515    } else {
516        Ok(())
517    }
518}
519
520#[cfg(not(target_arch = "wasm32"))]
521/// Check if the input tests all patterns in the query, by successively disabling
522/// all patterns. If disabling a pattern does not decrease the number of matches,
523/// then that pattern originally matched nothing in the input.
524pub fn check_query_coverage(
525    input_content: &str,
526    original_query: &TopiaryQuery,
527    grammar: &topiary_tree_sitter_facade::Language,
528) -> FormatterResult<CoverageData> {
529    let tree = parse(input_content, grammar, false)?;
530    let root = tree.root_node();
531    let source = input_content.as_bytes();
532    let mut missing_patterns = Vec::new();
533
534    // Match queries
535    let mut cursor = QueryCursor::new();
536    let ref_match_count = original_query
537        .query
538        .matches(&root, source, &mut cursor)
539        .count();
540    let pattern_count = original_query.query.pattern_count();
541    let query_content = &original_query.query_content;
542
543    // If there are no queries at all (e.g., when debugging) return early
544    // rather than dividing by zero
545    if pattern_count == 0 {
546        let cover_percentage = 0.0;
547        return Ok(CoverageData {
548            cover_percentage,
549            missing_patterns,
550        });
551    }
552
553    // This particular test avoids a SIGSEGV error that occurs when trying
554    // to count the matches of an empty query (see #481)
555    if pattern_count == 1 {
556        let mut cover_percentage = 1.0;
557        if ref_match_count == 0 {
558            missing_patterns.push(query_content.into());
559            cover_percentage = 0.0
560        }
561        return Ok(CoverageData {
562            cover_percentage,
563            missing_patterns,
564        });
565    }
566
567    let mut ok_patterns = 0.0;
568    for i in 0..pattern_count {
569        // We don't need to use TopiaryQuery in this test since we have no need
570        // for duplicate versions of the query_content string, instead we create the query
571        // manually.
572        let mut query = Query::new(grammar, query_content)
573            .map_err(|e| FormatterError::Query("Error parsing query file".into(), Some(e)))?;
574        query.disable_pattern(i);
575        let mut cursor = QueryCursor::new();
576        let match_count = query.matches(&root, source, &mut cursor).count();
577        if match_count == ref_match_count {
578            let index_start = query.start_byte_for_pattern(i);
579            let index_end = if i == pattern_count - 1 {
580                query_content.len()
581            } else {
582                query.start_byte_for_pattern(i + 1)
583            };
584            let pattern_content = &query_content[index_start..index_end];
585            missing_patterns.push(pattern_content.into());
586        } else {
587            ok_patterns += 1.0;
588        }
589    }
590
591    let cover_percentage = ok_patterns / pattern_count as f32;
592    Ok(CoverageData {
593        cover_percentage,
594        missing_patterns,
595    })
596}
597
/// WASM counterpart of `check_query_coverage`. Coverage checking is not
/// implemented for this target, so calling it panics.
#[cfg(target_arch = "wasm32")]
pub fn check_query_coverage(
    _input_content: &str,
    _original_query: &TopiaryQuery,
    _grammar: &topiary_tree_sitter_facade::Language,
) -> FormatterResult<CoverageData> {
    unimplemented!()
}