tree_house/
highlighter.rs

1use std::borrow::Cow;
2use std::cmp;
3use std::fmt;
4use std::mem::replace;
5use std::num::NonZeroU32;
6use std::ops::RangeBounds;
7use std::slice;
8use std::sync::Arc;
9
10use crate::config::{LanguageConfig, LanguageLoader};
11use crate::locals::ScopeCursor;
12use crate::query_iter::{MatchedNode, QueryIter, QueryIterEvent, QueryLoader};
13use crate::{Injection, Language, Layer, Syntax};
14use arc_swap::ArcSwap;
15use hashbrown::{HashMap, HashSet};
16use ropey::RopeSlice;
17use tree_sitter::{
18    query::{self, InvalidPredicateError, Query, UserPredicate},
19    Capture, Grammar,
20};
21use tree_sitter::{Pattern, QueryMatch};
22
23/// Contains the data needed to highlight code written in a particular language.
24///
25/// This struct is immutable and can be shared between threads.
26#[derive(Debug)]
27pub struct HighlightQuery {
28    pub query: Query,
29    highlight_indices: ArcSwap<Vec<Option<Highlight>>>,
30    #[allow(dead_code)]
31    /// Patterns that do not match when the node is a local.
32    non_local_patterns: HashSet<Pattern>,
33    local_reference_capture: Option<Capture>,
34}
35
36impl HighlightQuery {
37    pub(crate) fn new(
38        grammar: Grammar,
39        highlight_query_text: &str,
40        local_query_text: &str,
41    ) -> Result<Self, query::ParseError> {
42        // Concatenate the highlights and locals queries.
43        let mut query_source =
44            String::with_capacity(highlight_query_text.len() + local_query_text.len());
45        query_source.push_str(highlight_query_text);
46        query_source.push_str(local_query_text);
47
48        let mut non_local_patterns = HashSet::new();
49        let mut query = Query::new(grammar, &query_source, |pattern, predicate| {
50            match predicate {
51                // Allow the `(#set! local.scope-inherits <bool>)` property to be parsed.
52                // This information is not used by this query though, it's used in the
53                // injection query instead.
54                UserPredicate::SetProperty {
55                    key: "local.scope-inherits",
56                    ..
57                } => (),
58                // TODO: `(#is(-not)? local)` applies to the entire pattern. Ideally you
59                // should be able to supply capture(s?) which are each checked.
60                UserPredicate::IsPropertySet {
61                    negate: true,
62                    key: "local",
63                    val: None,
64                } => {
65                    non_local_patterns.insert(pattern);
66                }
67                _ => return Err(InvalidPredicateError::unknown(predicate)),
68            }
69            Ok(())
70        })?;
71
72        // The highlight query only cares about local.reference captures. All scope and definition
73        // captures can be disabled.
74        query.disable_capture("local.scope");
75        let local_definition_captures: Vec<_> = query
76            .captures()
77            .filter(|&(_, name)| name.starts_with("local.definition."))
78            .map(|(_, name)| Box::<str>::from(name))
79            .collect();
80        for name in local_definition_captures {
81            query.disable_capture(&name);
82        }
83
84        Ok(Self {
85            highlight_indices: ArcSwap::from_pointee(vec![None; query.num_captures() as usize]),
86            non_local_patterns,
87            local_reference_capture: query.get_capture("local.reference"),
88            query,
89        })
90    }
91
92    /// Configures the list of recognized highlight names.
93    ///
94    /// Tree-sitter syntax-highlighting queries specify highlights in the form of dot-separated
95    /// highlight names like `punctuation.bracket` and `function.method.builtin`. Consumers of
96    /// these queries can choose to recognize highlights with different levels of specificity.
97    /// For example, the string `function.builtin` will match against `function.builtin.constructor`
98    /// but will not match `function.method.builtin` and `function.method`.
99    ///
100    /// The closure provided to this function should therefore try to first lookup the full
101    /// name. If no highlight was found for that name it should [`rsplit_once('.')`](str::rsplit_once)
102    /// and retry until a highlight has been found. If none of the parent scopes are defined
103    /// then `Highlight::NONE` should be returned.
104    ///
105    /// When highlighting, results are returned as `Highlight` values, configured by this function.
106    /// The meaning of these indices is up to the user of the implementation. The highlighter
107    /// treats the indices as entirely opaque.
108    pub(crate) fn configure(&self, f: &mut impl FnMut(&str) -> Option<Highlight>) {
109        let highlight_indices = self
110            .query
111            .captures()
112            .map(|(_, capture_name)| f(capture_name))
113            .collect();
114        self.highlight_indices.store(Arc::new(highlight_indices));
115    }
116}
117
/// Indicates which highlight should be applied to a region of source code.
///
/// This type is represented as a non-max u32 - a u32 which cannot be `u32::MAX`. This is checked
/// at runtime with assertions in `Highlight::new`. The value is stored XOR-ed with `u32::MAX`
/// inside a [`NonZeroU32`], so `Option<Highlight>` is the same size as a `u32`.
#[derive(Copy, Clone, PartialEq, Eq)]
pub struct Highlight(NonZeroU32);

impl Highlight {
    /// The largest value a `Highlight` can hold.
    pub const MAX: u32 = u32::MAX - 1;

    /// Creates a highlight holding `inner`.
    ///
    /// # Panics
    ///
    /// Panics if `inner` is `u32::MAX`.
    pub const fn new(inner: u32) -> Self {
        assert!(inner != u32::MAX);
        // `inner ^ u32::MAX` is zero only when `inner == u32::MAX`, which was rejected above.
        // The checked constructor therefore always succeeds, and no `unsafe` is needed.
        match NonZeroU32::new(inner ^ u32::MAX) {
            Some(encoded) => Self(encoded),
            None => unreachable!(),
        }
    }

    /// Returns the value that was passed to [`Highlight::new`].
    pub const fn get(&self) -> u32 {
        self.0.get() ^ u32::MAX
    }

    /// Returns the value as a `usize`, suitable for indexing.
    pub const fn idx(&self) -> usize {
        self.get() as usize
    }
}

impl fmt::Debug for Highlight {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_tuple("Highlight").field(&self.get()).finish()
    }
}
148
149#[derive(Debug)]
150struct HighlightedNode {
151    end: u32,
152    highlight: Highlight,
153}
154
155#[derive(Debug, Default)]
156pub struct LayerData {
157    parent_highlights: usize,
158    dormant_highlights: Vec<HighlightedNode>,
159}
160
161pub struct Highlighter<'a, 'tree, Loader: LanguageLoader> {
162    query: QueryIter<'a, 'tree, HighlightQueryLoader<&'a Loader>, ()>,
163    next_query_event: Option<QueryIterEvent<'tree, ()>>,
164    /// The stack of currently active highlights.
165    /// The ranges of the highlights stack, so each highlight in the Vec must have a starting
166    /// point `>=` the starting point of the next highlight in the Vec and and ending point `<=`
167    /// the ending point of the next highlight in the Vec.
168    ///
169    /// For a visual:
170    ///
171    /// ```text
172    ///     | C |
173    ///   |   B   |
174    /// |     A    |
175    /// ```
176    ///
177    /// would be `vec![A, B, C]`.
178    active_highlights: Vec<HighlightedNode>,
179    next_highlight_end: u32,
180    next_highlight_start: u32,
181    active_config: Option<&'a LanguageConfig>,
182    // The current layer and per-layer state could be tracked on the QueryIter itself (see
183    // `QueryIter::current_layer` and `QueryIter::layer_state`) however the highlighter peeks the
184    // query iter. The query iter is always one event ahead, so it will enter/exit injections
185    // before we get a chance to in the highlighter. So instead we track these on the highlighter.
186    // Also see `Self::advance_query_iter`.
187    current_layer: Layer,
188    layer_states: HashMap<Layer, LayerData>,
189}
190
191pub struct HighlightList<'a>(slice::Iter<'a, HighlightedNode>);
192
193impl Iterator for HighlightList<'_> {
194    type Item = Highlight;
195
196    fn next(&mut self) -> Option<Highlight> {
197        self.0.next().map(|node| node.highlight)
198    }
199
200    fn size_hint(&self) -> (usize, Option<usize>) {
201        self.0.size_hint()
202    }
203}
204
205impl DoubleEndedIterator for HighlightList<'_> {
206    fn next_back(&mut self) -> Option<Self::Item> {
207        self.0.next_back().map(|node| node.highlight)
208    }
209}
210
211impl ExactSizeIterator for HighlightList<'_> {
212    fn len(&self) -> usize {
213        self.0.len()
214    }
215}
216
/// How the highlights returned by `Highlighter::advance` relate to the highlights that were
/// active before the call.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HighlightEvent {
    /// Reset the active set of highlights to the given ones.
    Refresh,
    /// Add more highlights which build on the existing highlights.
    Push,
}
224
225impl<'a, 'tree: 'a, Loader: LanguageLoader> Highlighter<'a, 'tree, Loader> {
226    pub fn new(
227        syntax: &'tree Syntax,
228        src: RopeSlice<'a>,
229        loader: &'a Loader,
230        range: impl RangeBounds<u32>,
231    ) -> Self {
232        let mut query = QueryIter::new(syntax, src, HighlightQueryLoader(loader), range);
233        let active_language = query.current_language();
234        let mut res = Highlighter {
235            active_config: query.loader().0.get_config(active_language),
236            next_query_event: None,
237            current_layer: query.current_layer(),
238            layer_states: Default::default(),
239            active_highlights: Vec::new(),
240            next_highlight_end: u32::MAX,
241            next_highlight_start: 0,
242            query,
243        };
244        res.advance_query_iter();
245        res
246    }
247
248    pub fn active_highlights(&self) -> HighlightList<'_> {
249        HighlightList(self.active_highlights.iter())
250    }
251
252    pub fn next_event_offset(&self) -> u32 {
253        self.next_highlight_start.min(self.next_highlight_end)
254    }
255
256    pub fn advance(&mut self) -> (HighlightEvent, HighlightList<'_>) {
257        let mut refresh = false;
258        let prev_stack_size = self.active_highlights.len();
259
260        let pos = self.next_event_offset();
261        if self.next_highlight_end == pos {
262            self.process_highlight_end(pos);
263            refresh = true;
264        }
265
266        let mut first_highlight = true;
267        while self.next_highlight_start == pos {
268            let Some(query_event) = self.advance_query_iter() else {
269                break;
270            };
271            match query_event {
272                QueryIterEvent::EnterInjection(injection) => self.enter_injection(injection.layer),
273                QueryIterEvent::Match(node) => self.start_highlight(node, &mut first_highlight),
274                QueryIterEvent::ExitInjection { injection, state } => {
275                    // `state` is returned if the layer is finished according to the `QueryIter`.
276                    // The highlighter should only consider a layer finished, though, when it also
277                    // has no remaining ranges to highlight. If the injection is combined and has
278                    // highlight(s) past this injection's range then we should deactivate it
279                    // (saving the highlights for the layer's next injection range) rather than
280                    // removing it.
281                    let parent_start = self
282                        .layer_states
283                        .get(&self.current_layer)
284                        .map(|layer| layer.parent_highlights)
285                        .unwrap_or_default()
286                        .min(self.active_highlights.len());
287                    let layer_is_finished = state.is_some()
288                        && self.active_highlights[parent_start..]
289                            .iter()
290                            .all(|h| h.end <= injection.range.end);
291                    if layer_is_finished {
292                        self.layer_states.remove(&injection.layer);
293                    } else {
294                        self.deactivate_layer(injection);
295                        refresh = true;
296                    }
297                    let active_language = self.query.syntax().layer(self.current_layer).language;
298                    self.active_config = self.query.loader().0.get_config(active_language);
299                }
300            }
301        }
302        self.next_highlight_end = self
303            .active_highlights
304            .last()
305            .map_or(u32::MAX, |node| node.end);
306
307        if refresh {
308            (
309                HighlightEvent::Refresh,
310                HighlightList(self.active_highlights.iter()),
311            )
312        } else {
313            (
314                HighlightEvent::Push,
315                HighlightList(self.active_highlights[prev_stack_size..].iter()),
316            )
317        }
318    }
319
320    fn advance_query_iter(&mut self) -> Option<QueryIterEvent<'tree, ()>> {
321        // Track the current layer **before** calling `QueryIter::next`. The QueryIter moves
322        // to the next event with `QueryIter::next` but we're treating that event as peeked - it
323        // hasn't occurred yet - so the current layer is the one the query iter was on _before_
324        // `QueryIter::next`.
325        self.current_layer = self.query.current_layer();
326        let event = replace(&mut self.next_query_event, self.query.next());
327        self.next_highlight_start = self
328            .next_query_event
329            .as_ref()
330            .map_or(u32::MAX, |event| event.start_byte());
331        event
332    }
333
334    fn process_highlight_end(&mut self, pos: u32) {
335        let i = self
336            .active_highlights
337            .iter()
338            .rposition(|highlight| highlight.end != pos)
339            .map_or(0, |i| i + 1);
340        self.active_highlights.truncate(i);
341    }
342
343    fn enter_injection(&mut self, layer: Layer) {
344        debug_assert_eq!(layer, self.current_layer);
345        let active_language = self.query.syntax().layer(layer).language;
346        self.active_config = self.query.loader().0.get_config(active_language);
347
348        let state = self.layer_states.entry(layer).or_default();
349        state.parent_highlights = self.active_highlights.len();
350        self.active_highlights.append(&mut state.dormant_highlights);
351    }
352
353    fn deactivate_layer(&mut self, injection: Injection) {
354        let LayerData {
355            mut parent_highlights,
356            ref mut dormant_highlights,
357            ..
358        } = self.layer_states.get_mut(&injection.layer).unwrap();
359        parent_highlights = parent_highlights.min(self.active_highlights.len());
360        dormant_highlights.extend(self.active_highlights.drain(parent_highlights..));
361        self.process_highlight_end(injection.range.end);
362    }
363
364    fn start_highlight(&mut self, node: MatchedNode, first_highlight: &mut bool) {
365        let range = node.node.byte_range();
366        // `<QueryIter as Iterator>::next` skips matches with empty ranges.
367        debug_assert!(
368            !range.is_empty(),
369            "QueryIter should not emit matches with empty ranges"
370        );
371
372        let config = self
373            .active_config
374            .expect("must have an active config to emit matches");
375
376        let highlight = if Some(node.capture) == config.highlight_query.local_reference_capture {
377            // If this capture was a `@local.reference` from the locals queries, look up the
378            // text of the node in the current locals cursor and use that highlight.
379            let text: Cow<str> = self
380                .query
381                .source()
382                .byte_slice(range.start as usize..range.end as usize)
383                .into();
384            let Some(definition) = self
385                .query
386                .syntax()
387                .layer(self.current_layer)
388                .locals
389                .lookup_reference(node.scope, &text)
390                .filter(|def| range.start >= def.range.end)
391            else {
392                return;
393            };
394            config
395                .injection_query
396                .local_definition_captures
397                .load()
398                .get(&definition.capture)
399                .copied()
400        } else {
401            config.highlight_query.highlight_indices.load()[node.capture.idx()]
402        };
403
404        let highlight = highlight.map(|highlight| HighlightedNode {
405            end: range.end,
406            highlight,
407        });
408
409        // If multiple patterns match this exact node, prefer the last one which matched.
410        // This matches the precedence of Neovim, Zed, and tree-sitter-cli.
411        if !*first_highlight {
412            // NOTE: `!*first_highlight` implies that the start positions are the same.
413            let insert_position = self
414                .active_highlights
415                .iter()
416                .rposition(|h| h.end <= range.end);
417            if let Some(idx) = insert_position {
418                match self.active_highlights[idx].end.cmp(&range.end) {
419                    // If there is a prior highlight for this start..end range, replace it.
420                    cmp::Ordering::Equal => {
421                        if let Some(highlight) = highlight {
422                            self.active_highlights[idx] = highlight;
423                        } else {
424                            self.active_highlights.remove(idx);
425                        }
426                    }
427                    // Captures are emitted in the order that they are finished. Insert any
428                    // highlights which start at the same position into the active highlights so
429                    // that the ordering invariant remains satisfied.
430                    cmp::Ordering::Less => {
431                        if let Some(highlight) = highlight {
432                            self.active_highlights.insert(idx, highlight)
433                        }
434                    }
435                    // By definition of our `rposition` predicate:
436                    cmp::Ordering::Greater => unreachable!(),
437                }
438            } else {
439                self.active_highlights.extend(highlight);
440            }
441        } else if let Some(highlight) = highlight {
442            self.active_highlights.push(highlight);
443            *first_highlight = false;
444        }
445
446        // `active_highlights` must be a stack of highlight events the highlights stack on the
447        // prior highlights in the Vec. Each highlight's range must be a subset of the highlight's
448        // range before it.
449        debug_assert!(
450            {
451                // The assertion is actually true for the entire stack but combined injections
452                // throw a wrench in things: the highlight can end after the current injection.
453                // The highlight is removed from `active_highlights` as the injection layer ends
454                // so the wider assertion would be true in practice. We don't track the injection
455                // end right here though so we can't assert on it.
456                let layer_start = self
457                    .layer_states
458                    .get(&self.current_layer)
459                    .map(|layer| layer.parent_highlights)
460                    .unwrap_or_default();
461
462                self.active_highlights[layer_start..].is_sorted_by_key(|h| cmp::Reverse(h.end))
463            },
464            "unsorted highlights on layer {:?}: {:?}\nall active highlights must be sorted by `end` descending",
465            self.current_layer,
466            self.active_highlights,
467        );
468    }
469}
470
471pub(crate) struct HighlightQueryLoader<T>(T);
472
473impl<'a, T: LanguageLoader> QueryLoader<'a> for HighlightQueryLoader<&'a T> {
474    fn get_query(&mut self, lang: Language) -> Option<&'a Query> {
475        self.0
476            .get_config(lang)
477            .map(|config| &config.highlight_query.query)
478    }
479
480    fn are_predicates_satisfied(
481        &self,
482        lang: Language,
483        mat: &QueryMatch<'_, '_>,
484        source: RopeSlice<'_>,
485        locals_cursor: &ScopeCursor<'_>,
486    ) -> bool {
487        let highlight_query = &self
488            .0
489            .get_config(lang)
490            .expect("must have a config to emit matches")
491            .highlight_query;
492
493        // Highlight queries should reject the match when a pattern is marked with
494        // `(#is-not? local)` and any capture in the pattern matches a definition in scope.
495        //
496        // TODO: in the future we should propose that `#is-not? local` takes one or more
497        // captures as arguments. Ideally we would check that the captured node is also captured
498        // by a `local.reference` capture from the locals query but that's really messy to pass
499        // around that information. For now we assume that all matches in the pattern are also
500        // captured as `local.reference` in the locals, which covers most cases.
501        if highlight_query.local_reference_capture.is_some()
502            && highlight_query.non_local_patterns.contains(&mat.pattern())
503        {
504            let has_local_reference = mat.matched_nodes().any(|n| {
505                let range = n.node.byte_range();
506                let text: Cow<str> = source
507                    .byte_slice(range.start as usize..range.end as usize)
508                    .into();
509                locals_cursor
510                    .locals
511                    .lookup_reference(locals_cursor.current_scope(), &text)
512                    .is_some_and(|def| range.start >= def.range.start)
513            });
514            if has_local_reference {
515                return false;
516            }
517        }
518
519        true
520    }
521}