tree_house/
query_iter.rs

1use core::slice;
2use std::iter::Peekable;
3use std::mem::replace;
4use std::ops::RangeBounds;
5
6use hashbrown::{HashMap, HashSet};
7use ropey::RopeSlice;
8
9use crate::{
10    locals::{Scope, ScopeCursor},
11    Injection, Language, Layer, Range, Syntax, TREE_SITTER_MATCH_LIMIT,
12};
13use tree_sitter::{
14    Capture, InactiveQueryCursor, Node, Pattern, Query, QueryCursor, QueryMatch, RopeInput,
15};
16
17#[derive(Debug, Clone)]
18pub struct MatchedNode<'tree> {
19    pub match_id: u32,
20    pub pattern: Pattern,
21    pub node: Node<'tree>,
22    pub capture: Capture,
23    pub scope: Scope,
24}
25
26struct LayerQueryIter<'a, 'tree> {
27    cursor: Option<QueryCursor<'a, 'tree, RopeInput<'a>>>,
28    peeked: Option<MatchedNode<'tree>>,
29    language: Language,
30    scope_cursor: ScopeCursor<'tree>,
31}
32
33impl<'a, 'tree> LayerQueryIter<'a, 'tree> {
34    fn peek<Loader: QueryLoader<'a>>(
35        &mut self,
36        source: RopeSlice<'_>,
37        loader: &Loader,
38    ) -> Option<&MatchedNode<'tree>> {
39        if self.peeked.is_none() {
40            loop {
41                // NOTE: we take the cursor here so that if `next_matched_node` is None the
42                // cursor is dropped and returned to the cache eagerly.
43                let mut cursor = self.cursor.take()?;
44                let (query_match, node_idx) = cursor.next_matched_node()?;
45                let node = query_match.matched_node(node_idx);
46                let match_id = query_match.id();
47                let pattern = query_match.pattern();
48                let range = node.node.byte_range();
49                let scope = self.scope_cursor.advance(range.start);
50
51                if !loader.are_predicates_satisfied(
52                    self.language,
53                    &query_match,
54                    source,
55                    &self.scope_cursor,
56                ) {
57                    query_match.remove();
58                    self.cursor = Some(cursor);
59                    continue;
60                }
61
62                self.peeked = Some(MatchedNode {
63                    match_id,
64                    pattern,
65                    // NOTE: `Node` is cheap to clone, it's essentially Copy.
66                    node: node.node.clone(),
67                    capture: node.capture,
68                    scope,
69                });
70                self.cursor = Some(cursor);
71                break;
72            }
73        }
74        self.peeked.as_ref()
75    }
76
77    fn consume(&mut self) -> MatchedNode<'tree> {
78        self.peeked.take().unwrap()
79    }
80}
81
82struct ActiveLayer<'a, 'tree, S> {
83    state: S,
84    query_iter: LayerQueryIter<'a, 'tree>,
85    injections: Peekable<slice::Iter<'a, Injection>>,
86}
87
88// data only needed when entering and exiting injections
89// separate struck to keep the QueryIter reasonably small
90struct QueryIterLayerManager<'a, 'tree, Loader, S> {
91    range: Range,
92    loader: Loader,
93    src: RopeSlice<'a>,
94    syntax: &'tree Syntax,
95    active_layers: HashMap<Layer, Box<ActiveLayer<'a, 'tree, S>>>,
96    active_injections: Vec<Injection>,
97    /// Layers which are known to have no more captures.
98    finished_layers: HashSet<Layer>,
99}
100
101impl<'a, 'tree: 'a, Loader, S> QueryIterLayerManager<'a, 'tree, Loader, S>
102where
103    Loader: QueryLoader<'a>,
104    S: Default,
105{
106    fn init_layer(&mut self, injection: Injection) -> Box<ActiveLayer<'a, 'tree, S>> {
107        self.active_layers
108            .remove(&injection.layer)
109            .unwrap_or_else(|| {
110                let layer = self.syntax.layer(injection.layer);
111                let start_point = injection.range.start.max(self.range.start);
112                let injection_start = layer
113                    .injections
114                    .partition_point(|child| child.range.end < start_point);
115                let cursor = if self.finished_layers.contains(&injection.layer) {
116                    // If the layer has no more captures, skip creating a cursor.
117                    None
118                } else {
119                    self.loader
120                        .get_query(layer.language)
121                        .and_then(|query| Some((query, layer.tree()?.root_node())))
122                        .map(|(query, node)| {
123                            InactiveQueryCursor::new(self.range.clone(), TREE_SITTER_MATCH_LIMIT)
124                                .execute_query(query, &node, RopeInput::new(self.src))
125                        })
126                };
127                Box::new(ActiveLayer {
128                    state: S::default(),
129                    query_iter: LayerQueryIter {
130                        language: layer.language,
131                        cursor,
132                        peeked: None,
133                        scope_cursor: layer.locals.scope_cursor(self.range.start),
134                    },
135                    injections: layer.injections[injection_start..].iter().peekable(),
136                })
137            })
138    }
139}
140
141pub struct QueryIter<'a, 'tree, Loader: QueryLoader<'a>, LayerState = ()> {
142    layer_manager: Box<QueryIterLayerManager<'a, 'tree, Loader, LayerState>>,
143    current_layer: Box<ActiveLayer<'a, 'tree, LayerState>>,
144    current_injection: Injection,
145}
146
147impl<'a, 'tree: 'a, Loader, LayerState> QueryIter<'a, 'tree, Loader, LayerState>
148where
149    Loader: QueryLoader<'a>,
150    LayerState: Default,
151{
152    pub fn new(
153        syntax: &'tree Syntax,
154        src: RopeSlice<'a>,
155        loader: Loader,
156        range: impl RangeBounds<u32>,
157    ) -> Self {
158        let start = match range.start_bound() {
159            std::ops::Bound::Included(&i) => i,
160            std::ops::Bound::Excluded(&i) => i + 1,
161            std::ops::Bound::Unbounded => 0,
162        };
163        let end = match range.end_bound() {
164            std::ops::Bound::Included(&i) => i + 1,
165            std::ops::Bound::Excluded(&i) => i,
166            std::ops::Bound::Unbounded => src.len_bytes() as u32,
167        };
168        let range = start..end;
169        let node = syntax.tree().root_node();
170        // create fake injection for query root
171        let injection = Injection {
172            range: node.byte_range(),
173            layer: syntax.root,
174            matched_node_range: node.byte_range(),
175        };
176        let mut layer_manager = Box::new(QueryIterLayerManager {
177            range,
178            loader,
179            src,
180            syntax,
181            // TODO: reuse allocations with an allocation pool
182            active_layers: HashMap::with_capacity(8),
183            active_injections: Vec::with_capacity(8),
184            finished_layers: HashSet::with_capacity(8),
185        });
186        Self {
187            current_layer: layer_manager.init_layer(injection.clone()),
188            current_injection: injection,
189            layer_manager,
190        }
191    }
192
193    #[inline]
194    pub fn source(&self) -> RopeSlice<'a> {
195        self.layer_manager.src
196    }
197
198    #[inline]
199    pub fn syntax(&self) -> &'tree Syntax {
200        self.layer_manager.syntax
201    }
202
203    #[inline]
204    pub fn loader(&mut self) -> &mut Loader {
205        &mut self.layer_manager.loader
206    }
207
208    #[inline]
209    pub fn current_layer(&self) -> Layer {
210        self.current_injection.layer
211    }
212
213    #[inline]
214    pub fn current_injection(&mut self) -> (Injection, &mut LayerState) {
215        (
216            self.current_injection.clone(),
217            &mut self.current_layer.state,
218        )
219    }
220
221    #[inline]
222    pub fn current_language(&self) -> Language {
223        self.layer_manager
224            .syntax
225            .layer(self.current_injection.layer)
226            .language
227    }
228
229    pub fn layer_state(&mut self, layer: Layer) -> &mut LayerState {
230        if layer == self.current_injection.layer {
231            &mut self.current_layer.state
232        } else {
233            &mut self
234                .layer_manager
235                .active_layers
236                .get_mut(&layer)
237                .unwrap()
238                .state
239        }
240    }
241
242    fn enter_injection(&mut self, injection: Injection) {
243        let active_layer = self.layer_manager.init_layer(injection.clone());
244        let old_injection = replace(&mut self.current_injection, injection);
245        let old_layer = replace(&mut self.current_layer, active_layer);
246        self.layer_manager
247            .active_layers
248            .insert(old_injection.layer, old_layer);
249        self.layer_manager.active_injections.push(old_injection);
250    }
251
252    fn exit_injection(&mut self) -> Option<(Injection, Option<LayerState>)> {
253        let injection = replace(
254            &mut self.current_injection,
255            self.layer_manager.active_injections.pop()?,
256        );
257        let mut layer = replace(
258            &mut self.current_layer,
259            self.layer_manager
260                .active_layers
261                .remove(&self.current_injection.layer)?,
262        );
263        let layer_unfinished =
264            layer.query_iter.peeked.is_some() || layer.injections.peek().is_some();
265        if layer_unfinished {
266            self.layer_manager
267                .active_layers
268                .insert(injection.layer, layer);
269            Some((injection, None))
270        } else {
271            self.layer_manager.finished_layers.insert(injection.layer);
272            Some((injection, Some(layer.state)))
273        }
274    }
275}
276
277impl<'a, 'tree: 'a, Loader, S> Iterator for QueryIter<'a, 'tree, Loader, S>
278where
279    Loader: QueryLoader<'a>,
280    S: Default,
281{
282    type Item = QueryIterEvent<'tree, S>;
283
284    fn next(&mut self) -> Option<Self::Item> {
285        loop {
286            let next_injection = self
287                .current_layer
288                .injections
289                .peek()
290                .filter(|injection| injection.range.start <= self.current_injection.range.end);
291            let next_match = self
292                .current_layer
293                .query_iter
294                .peek(self.layer_manager.src, &self.layer_manager.loader)
295                .filter(|matched_node| {
296                    matched_node.node.start_byte() <= self.current_injection.range.end
297                });
298
299            match (next_match, next_injection) {
300                (None, None) => {
301                    return self.exit_injection().map(|(injection, state)| {
302                        QueryIterEvent::ExitInjection { injection, state }
303                    });
304                }
305                (Some(mat), _) if mat.node.byte_range().is_empty() => {
306                    self.current_layer.query_iter.consume();
307                    continue;
308                }
309                (Some(_), None) => {
310                    // consume match
311                    let matched_node = self.current_layer.query_iter.consume();
312                    return Some(QueryIterEvent::Match(matched_node));
313                }
314                (Some(matched_node), Some(injection))
315                    if matched_node.node.start_byte() < injection.range.end =>
316                {
317                    // consume match
318                    let matched_node = self.current_layer.query_iter.consume();
319                    // ignore nodes that are overlapped by the injection
320                    if matched_node.node.start_byte() <= injection.range.start
321                        || injection.range.end < matched_node.node.end_byte()
322                    {
323                        return Some(QueryIterEvent::Match(matched_node));
324                    }
325                }
326                (Some(_), Some(_)) | (None, Some(_)) => {
327                    // consume injection
328                    let injection = self.current_layer.injections.next().unwrap();
329                    self.enter_injection(injection.clone());
330                    return Some(QueryIterEvent::EnterInjection(injection.clone()));
331                }
332            }
333        }
334    }
335}
336
337#[derive(Debug)]
338pub enum QueryIterEvent<'tree, State = ()> {
339    EnterInjection(Injection),
340    Match(MatchedNode<'tree>),
341    ExitInjection {
342        injection: Injection,
343        state: Option<State>,
344    },
345}
346
347impl<S> QueryIterEvent<'_, S> {
348    pub fn start_byte(&self) -> u32 {
349        match self {
350            QueryIterEvent::EnterInjection(injection) => injection.range.start,
351            QueryIterEvent::Match(mat) => mat.node.start_byte(),
352            QueryIterEvent::ExitInjection { injection, .. } => injection.range.end,
353        }
354    }
355}
356
357pub trait QueryLoader<'a> {
358    fn get_query(&mut self, lang: Language) -> Option<&'a Query>;
359
360    fn are_predicates_satisfied(
361        &self,
362        _lang: Language,
363        _match: &QueryMatch<'_, '_>,
364        _source: RopeSlice<'_>,
365        _locals_cursor: &ScopeCursor<'_>,
366    ) -> bool {
367        true
368    }
369}
370
371impl<'a, F> QueryLoader<'a> for F
372where
373    F: FnMut(Language) -> Option<&'a Query>,
374{
375    fn get_query(&mut self, lang: Language) -> Option<&'a Query> {
376        (self)(lang)
377    }
378}