tree_house_bindings/
query.rs

1use std::fmt::{self, Display};
2use std::ops::Range;
3use std::ptr::NonNull;
4use std::{slice, str};
5
6use crate::query::predicate::TextPredicate;
7pub use crate::query::predicate::{InvalidPredicateError, Predicate};
8use crate::Grammar;
9
10mod predicate;
11mod property;
12
13#[derive(Debug)]
14pub enum UserPredicate<'a> {
15    IsPropertySet {
16        negate: bool,
17        key: &'a str,
18        val: Option<&'a str>,
19    },
20    SetProperty {
21        key: &'a str,
22        val: Option<&'a str>,
23    },
24    Other(Predicate<'a>),
25}
26
27impl Display for UserPredicate<'_> {
28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29        match *self {
30            UserPredicate::IsPropertySet { negate, key, val } => {
31                let predicate = if negate { "is-not?" } else { "is?" };
32                let spacer = if val.is_some() { " " } else { "" };
33                write!(f, " (#{predicate} {key}{spacer}{})", val.unwrap_or(""))
34            }
35            UserPredicate::SetProperty { key, val } => {
36                let spacer = if val.is_some() { " " } else { "" };
37                write!(f, "(#set! {key}{spacer}{})", val.unwrap_or(""))
38            }
39            UserPredicate::Other(ref predicate) => {
40                write!(f, "#{}", predicate.name())
41            }
42        }
43    }
44}
45
/// Index of a pattern within a query.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Pattern(pub(crate) u32);

impl Pattern {
    /// Placeholder value that uses the largest representable index.
    pub const SENTINEL: Pattern = Pattern(u32::MAX);

    /// The pattern index widened to `usize`, e.g. for indexing into slices.
    pub fn idx(&self) -> usize {
        let Pattern(raw) = *self;
        raw as usize
    }
}
56
/// Opaque stand-in for the query object owned by the tree-sitter C library.
///
/// The type has no variants, so it can never be constructed from Rust; it is
/// only ever used behind a `NonNull<QueryData>` pointer returned by
/// `ts_query_new`.
pub enum QueryData {}
58
59#[derive(Debug)]
60pub(super) struct PatternData {
61    text_predicates: Range<u32>,
62}
63
64#[derive(Debug)]
65pub struct Query {
66    pub(crate) raw: NonNull<QueryData>,
67    num_captures: u32,
68    num_strings: u32,
69    text_predicates: Vec<TextPredicate>,
70    patterns: Box<[PatternData]>,
71}
72
73unsafe impl Send for Query {}
74unsafe impl Sync for Query {}
75
76impl Query {
77    /// Create a new query from a string containing one or more S-expression
78    /// patterns.
79    ///
80    /// The query is associated with a particular grammar, and can only be run
81    /// on syntax nodes parsed with that grammar. References to Queries can be
82    /// shared between multiple threads.
83    pub fn new(
84        grammar: Grammar,
85        source: &str,
86        mut custom_predicate: impl FnMut(Pattern, UserPredicate) -> Result<(), InvalidPredicateError>,
87    ) -> Result<Self, ParseError> {
88        assert!(
89            source.len() <= i32::MAX as usize,
90            "TreeSitter queries must be smaller then 2 GiB (is {})",
91            source.len() as f64 / 1024.0 / 1024.0 / 1024.0
92        );
93        let mut error_offset = 0u32;
94        let mut error_kind = RawQueryError::None;
95        let bytes = source.as_bytes();
96
97        // Compile the query.
98        let ptr = unsafe {
99            ts_query_new(
100                grammar,
101                bytes.as_ptr(),
102                bytes.len() as u32,
103                &mut error_offset,
104                &mut error_kind,
105            )
106        };
107
108        let Some(raw) = ptr else {
109            let offset = error_offset as usize;
110            let error_word = || {
111                source[offset..]
112                    .chars()
113                    .take_while(|&c| c.is_alphanumeric() || matches!(c, '_' | '-'))
114                    .collect()
115            };
116            let err = match error_kind {
117                RawQueryError::NodeType => {
118                    let node: String = error_word();
119                    ParseError::InvalidNodeType {
120                        location: ParserErrorLocation::new(source, offset, node.chars().count()),
121                        node,
122                    }
123                }
124                RawQueryError::Field => {
125                    let field = error_word();
126                    ParseError::InvalidFieldName {
127                        location: ParserErrorLocation::new(source, offset, field.chars().count()),
128                        field,
129                    }
130                }
131                RawQueryError::Capture => {
132                    let capture = error_word();
133                    ParseError::InvalidCaptureName {
134                        location: ParserErrorLocation::new(source, offset, capture.chars().count()),
135                        capture,
136                    }
137                }
138                RawQueryError::Syntax => {
139                    ParseError::SyntaxError(ParserErrorLocation::new(source, offset, 0))
140                }
141                RawQueryError::Structure => {
142                    ParseError::ImpossiblePattern(ParserErrorLocation::new(source, offset, 0))
143                }
144                RawQueryError::None => {
145                    unreachable!("tree-sitter returned a null pointer but did not set an error")
146                }
147                RawQueryError::Language => unreachable!("should be handled at grammar load"),
148            };
149            return Err(err);
150        };
151
152        // I am not going to bother with safety comments here, all of these are
153        // safe as long as TS is not buggy because raw is a properly constructed query
154        let num_captures = unsafe { ts_query_capture_count(raw) };
155        let num_strings = unsafe { ts_query_string_count(raw) };
156        let num_patterns = unsafe { ts_query_pattern_count(raw) };
157
158        let mut query = Query {
159            raw,
160            num_captures,
161            num_strings,
162            text_predicates: Vec::new(),
163            patterns: Box::default(),
164        };
165        let patterns: Result<_, ParseError> = (0..num_patterns)
166            .map(|pattern| {
167                query
168                    .parse_pattern_predicates(Pattern(pattern), &mut custom_predicate)
169                    .map_err(|err| {
170                        let pattern_start =
171                            unsafe { ts_query_start_byte_for_pattern(query.raw, pattern) as usize };
172                        match err {
173                            InvalidPredicateError::UnknownPredicate { name } => {
174                                let offset = source[pattern_start..]
175                                    .find(&*name)
176                                    .expect("predicate name is a substring of the query text")
177                                    + pattern_start
178                                    // Subtract a byte for b'#'.
179                                    - 1;
180                                ParseError::InvalidPredicate {
181                                    message: format!("unknown predicate #{name}"),
182                                    location: ParserErrorLocation::new(
183                                        source,
184                                        offset,
185                                        // Add one char for the '#'.
186                                        name.chars().count() + 1,
187                                    ),
188                                }
189                            }
190                            InvalidPredicateError::UnknownProperty { property } => {
191                                // TODO: this is naive. We should ensure that it is within a
192                                // `#set!` or `#is(-not)?`.
193                                let offset = source[pattern_start..]
194                                    .find(&*property)
195                                    .expect("property name is a substring of the query text")
196                                    + pattern_start;
197                                ParseError::InvalidPredicate {
198                                    message: format!("unknown property '{property}'"),
199                                    location: ParserErrorLocation::new(
200                                        source,
201                                        offset,
202                                        property.chars().count(),
203                                    ),
204                                }
205                            }
206                            InvalidPredicateError::Other { msg } => ParseError::InvalidPredicate {
207                                message: msg.into(),
208                                location: ParserErrorLocation::new(source, pattern_start, 0),
209                            },
210                        }
211                    })
212            })
213            .collect();
214        query.patterns = patterns?;
215        Ok(query)
216    }
217
218    #[inline]
219    fn get_string(&self, str: QueryStr) -> &str {
220        let value_id = str.0;
221        // need an assertions because the ts c api does not do bounds check
222        assert!(value_id <= self.num_strings, "invalid value index");
223        unsafe {
224            let mut len = 0;
225            let ptr = ts_query_string_value_for_id(self.raw, value_id, &mut len);
226            let data = slice::from_raw_parts(ptr, len as usize);
227            // safety: we only allow passing valid str(ings) as arguments to query::new
228            // name is always a substring of that. Treesitter does proper utf8 segmentation
229            // so any substrings it produces are codepoint aligned and therefore valid utf8
230            str::from_utf8_unchecked(data)
231        }
232    }
233
234    #[inline]
235    pub fn capture_name(&self, capture_idx: Capture) -> &str {
236        let capture_idx = capture_idx.0;
237        // need an assertions because the ts c api does not do bounds check
238        assert!(capture_idx <= self.num_captures, "invalid capture index");
239        let mut length = 0;
240        unsafe {
241            let ptr = ts_query_capture_name_for_id(self.raw, capture_idx, &mut length);
242            let name = slice::from_raw_parts(ptr, length as usize);
243            // safety: we only allow passing valid str(ings) as arguments to query::new
244            // name is always a substring of that. Treesitter does proper utf8 segmentation
245            // so any substrings it produces are codepoint aligned and therefore valid utf8
246            str::from_utf8_unchecked(name)
247        }
248    }
249
250    #[inline]
251    pub fn captures(&self) -> impl ExactSizeIterator<Item = (Capture, &str)> {
252        (0..self.num_captures).map(|cap| (Capture(cap), self.capture_name(Capture(cap))))
253    }
254
255    #[inline]
256    pub fn num_captures(&self) -> u32 {
257        self.num_captures
258    }
259
260    #[inline]
261    pub fn get_capture(&self, capture_name: &str) -> Option<Capture> {
262        for capture in 0..self.num_captures {
263            if capture_name == self.capture_name(Capture(capture)) {
264                return Some(Capture(capture));
265            }
266        }
267        None
268    }
269
270    pub(crate) fn pattern_text_predicates(&self, pattern_idx: u16) -> &[TextPredicate] {
271        let range = self.patterns[pattern_idx as usize].text_predicates.clone();
272        &self.text_predicates[range.start as usize..range.end as usize]
273    }
274
275    /// Get the byte offset where the given pattern starts in the query's
276    /// source.
277    #[doc(alias = "ts_query_start_byte_for_pattern")]
278    #[must_use]
279    pub fn start_byte_for_pattern(&self, pattern: Pattern) -> usize {
280        assert!(
281            pattern.0 < self.text_predicates.len() as u32,
282            "Pattern index is {pattern:?} but the pattern count is {}",
283            self.text_predicates.len(),
284        );
285        unsafe { ts_query_start_byte_for_pattern(self.raw, pattern.0) as usize }
286    }
287
288    /// Get the number of patterns in the query.
289    #[must_use]
290    pub fn pattern_count(&self) -> usize {
291        unsafe { ts_query_pattern_count(self.raw) as usize }
292    }
293    /// Get the number of patterns in the query.
294    #[must_use]
295    pub fn patterns(&self) -> impl ExactSizeIterator<Item = Pattern> {
296        (0..self.pattern_count() as u32).map(Pattern)
297    }
298
299    /// Disable a certain capture within a query.
300    ///
301    /// This prevents the capture from being returned in matches, and also avoids
302    /// any resource usage associated with recording the capture. Currently, there
303    /// is no way to undo this.
304    pub fn disable_capture(&mut self, name: &str) {
305        let bytes = name.as_bytes();
306        unsafe {
307            ts_query_disable_capture(self.raw, bytes.as_ptr(), bytes.len() as u32);
308        }
309    }
310}
311
312impl Drop for Query {
313    fn drop(&mut self) {
314        unsafe { ts_query_delete(self.raw) }
315    }
316}
317
318#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
319#[repr(transparent)]
320pub struct Capture(u32);
321
322impl Capture {
323    pub fn name(self, query: &Query) -> &str {
324        query.capture_name(self)
325    }
326    pub fn idx(self) -> usize {
327        self.0 as usize
328    }
329}
330
331/// A reference to a string stored in a query
332#[derive(Clone, Copy, Debug)]
333pub struct QueryStr(u32);
334
335impl QueryStr {
336    pub fn get(self, query: &Query) -> &str {
337        query.get_string(self)
338    }
339}
340
/// Position of an error inside a query source, plus the surrounding lines
/// needed to render a diagnostic snippet.
#[derive(Debug, PartialEq, Eq)]
pub struct ParserErrorLocation {
    /// at which line the error occurred (zero-based)
    pub line: u32,
    /// at which codepoints/columns the errors starts in the line (zero-based)
    pub column: u32,
    /// how many codepoints/columns the error takes up
    pub len: u32,
    /// Text of the line containing the error, without the line terminator.
    line_content: String,
    /// The line preceding the error line, if there is one and it is not empty.
    line_before: Option<String>,
    /// The line following the error line, if there is one and it is not empty.
    line_after: Option<String>,
}

impl ParserErrorLocation {
    /// Locate byte offset `start` in `source` and record its line, its column
    /// (counted in codepoints) and the neighbouring lines.
    ///
    /// `len` is the width of the error in codepoints and is stored as given.
    ///
    /// An offset that falls inside a multi-byte character is attributed to
    /// that character. If `start` lies beyond the end of `source`, the
    /// location stays at line 0, column 0 with empty line contents.
    pub fn new(source: &str, start: usize, len: usize) -> ParserErrorLocation {
        let mut location = ParserErrorLocation {
            line: 0,
            column: 0,
            len: len as u32,
            line_content: String::new(),
            line_before: None,
            line_after: None,
        };

        // Byte offset at which the line currently being inspected begins.
        let mut line_start = 0;
        for (line_no, raw_line) in source.split('\n').enumerate() {
            let line_end = line_start + raw_line.len();
            if line_start <= start && start <= line_end {
                location.line = line_no as u32;
                // Count the characters that end at or before `start`. Unlike
                // slicing `source[line_start..start]`, this cannot panic when
                // `start` is not on a character boundary.
                location.column = raw_line
                    .char_indices()
                    .take_while(|&(idx, ch)| line_start + idx + ch.len_utf8() <= start)
                    .count() as u32;
                // `split('\n')` leaves the '\r' of a CRLF terminator in place.
                location.line_content =
                    raw_line.strip_suffix('\r').unwrap_or(raw_line).to_owned();
                location.line_before = source[..line_start]
                    .lines()
                    .next_back()
                    .filter(|text| !text.is_empty())
                    .map(ToOwned::to_owned);
                location.line_after = source
                    .get(line_end + 1..)
                    .and_then(|rest| rest.lines().next())
                    .filter(|text| !text.is_empty())
                    .map(ToOwned::to_owned);
                break;
            }
            // Skip the line and the '\n' that terminated it.
            line_start = line_end + 1;
        }

        location
    }
}

impl Display for ParserErrorLocation {
    /// Render a compiler-style snippet: a `--> line:column` header (both
    /// one-based), the error line with its neighbours, and a row of `^`
    /// markers underneath the offending span (a single `^` when `len` is 0).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "  --> {}:{}", self.line + 1, self.column + 1)?;

        // The widest line number shown decides the width of the gutter.
        let last_line_number = if self.line_after.is_some() {
            self.line + 2
        } else {
            self.line + 1
        };
        let width = last_line_number.to_string().len();
        let gutter = format!(" {:width$} |", "", width = width);

        writeln!(f, "{gutter}")?;
        // Line numbers are right-aligned to the gutter width so the `|`
        // separators line up even when the digit count changes (9 -> 10).
        if let Some(before) = self.line_before.as_ref() {
            writeln!(f, " {:>width$} | {}", self.line, before, width = width)?;
        }
        writeln!(
            f,
            " {:>width$} | {}",
            self.line + 1,
            self.line_content,
            width = width
        )?;
        writeln!(
            f,
            "{gutter}{:column$} {:^<len$}",
            "",
            "^",
            column = self.column as usize,
            len = self.len as usize
        )?;
        if let Some(after) = self.line_after.as_ref() {
            writeln!(f, " {:>width$} | {}", self.line + 2, after, width = width)?;
        }
        writeln!(f, "{gutter}")
    }
}
431
432#[derive(thiserror::Error, Debug, PartialEq, Eq)]
433pub enum ParseError {
434    #[error("unexpected EOF")]
435    UnexpectedEof,
436    #[error("invalid query syntax\n{0}")]
437    SyntaxError(ParserErrorLocation),
438    #[error("invalid node type {node:?}\n{location}")]
439    InvalidNodeType {
440        node: String,
441        location: ParserErrorLocation,
442    },
443    #[error("invalid field name {field:?}\n{location}")]
444    InvalidFieldName {
445        field: String,
446        location: ParserErrorLocation,
447    },
448    #[error("invalid capture name {capture:?}\n{location}")]
449    InvalidCaptureName {
450        capture: String,
451        location: ParserErrorLocation,
452    },
453    #[error("{message}\n{location}")]
454    InvalidPredicate {
455        message: String,
456        location: ParserErrorLocation,
457    },
458    #[error("impossible pattern\n{0}")]
459    ImpossiblePattern(ParserErrorLocation),
460}
461
/// Error kind that `ts_query_new` writes through its `error_type` parameter.
#[repr(C)]
// The compiler warns that these variants are never constructed, but they are
// constructed by C code and written into a mutable reference.
#[allow(dead_code)]
enum RawQueryError {
    /// No error was reported.
    None = 0,
    /// Reported to callers as `ParseError::SyntaxError`.
    Syntax = 1,
    /// Reported to callers as `ParseError::InvalidNodeType`.
    NodeType = 2,
    /// Reported to callers as `ParseError::InvalidFieldName`.
    Field = 3,
    /// Reported to callers as `ParseError::InvalidCaptureName`.
    Capture = 4,
    /// Reported to callers as `ParseError::ImpossiblePattern`.
    Structure = 5,
    /// Not expected from query compilation; `Query::new` treats it as
    /// unreachable because it should be handled at grammar load.
    Language = 6,
}
475
476extern "C" {
477    /// Create a new query from a string containing one or more S-expression
478    /// patterns. The query is associated with a particular language, and can
479    /// only be run on syntax nodes parsed with that language. If all of the
480    /// given patterns are valid, this returns a `TSQuery`. If a pattern is
481    /// invalid, this returns `NULL`, and provides two pieces of information
482    /// about the problem: 1. The byte offset of the error is written to
483    /// the `error_offset` parameter. 2. The type of error is written to the
484    /// `error_type` parameter.
485    fn ts_query_new(
486        grammar: Grammar,
487        source: *const u8,
488        source_len: u32,
489        error_offset: &mut u32,
490        error_type: &mut RawQueryError,
491    ) -> Option<NonNull<QueryData>>;
492
493    /// Delete a query, freeing all of the memory that it used.
494    fn ts_query_delete(query: NonNull<QueryData>);
495
496    /// Get the number of patterns, captures, or string literals in the query.
497    fn ts_query_pattern_count(query: NonNull<QueryData>) -> u32;
498    fn ts_query_capture_count(query: NonNull<QueryData>) -> u32;
499    fn ts_query_string_count(query: NonNull<QueryData>) -> u32;
500
501    /// Get the byte offset where the given pattern starts in the query's
502    /// source. This can be useful when combining queries by concatenating their
503    /// source code strings.
504    fn ts_query_start_byte_for_pattern(query: NonNull<QueryData>, pattern_index: u32) -> u32;
505
506    // fn ts_query_is_pattern_rooted(query: NonNull<QueryData>, pattern_index: u32) -> bool;
507    // fn ts_query_is_pattern_non_local(query: NonNull<QueryData>, pattern_index: u32) -> bool;
508    // fn ts_query_is_pattern_guaranteed_at_step(query: NonNull<QueryData>, byte_offset: u32) -> bool;
509    /// Get the name and length of one of the query's captures, or one of the
510    /// query's string literals. Each capture and string is associated with a
511    /// numeric id based on the order that it appeared in the query's source.
512    fn ts_query_capture_name_for_id(
513        query: NonNull<QueryData>,
514        index: u32,
515        length: &mut u32,
516    ) -> *const u8;
517
518    fn ts_query_string_value_for_id(
519        self_: NonNull<QueryData>,
520        index: u32,
521        length: &mut u32,
522    ) -> *const u8;
523
524    /// Disable a certain capture within a query.
525    ///
526    /// This prevents the capture from being returned in matches, and also avoids
527    /// any resource usage associated with recording the capture. Currently, there
528    /// is no way to undo this.
529    fn ts_query_disable_capture(self_: NonNull<QueryData>, name: *const u8, length: u32);
530}