tree_house_bindings/query/
predicate.rs

1use std::error::Error;
2use std::iter::zip;
3use std::ops::Range;
4use std::ptr::NonNull;
5use std::{fmt, slice};
6
7use crate::query::property::QueryProperty;
8use crate::query::{Capture, Pattern, PatternData, Query, QueryData, QueryStr, UserPredicate};
9use crate::query_cursor::MatchedNode;
10use crate::Input;
11
12use regex_cursor::engines::meta::Regex;
13use regex_cursor::Cursor;
14
/// Returns early from the enclosing function with
/// `Err(InvalidPredicateError::Other { .. })`.
///
/// The arguments are forwarded verbatim to `format!` to build the message.
macro_rules! bail {
    ($($fmt_args:tt)*) => {{
        let msg = format!($($fmt_args)*);
        return Err(InvalidPredicateError::Other { msg: msg.into() })
    }};
}
20
/// Returns early from the enclosing function with
/// `Err(InvalidPredicateError::Other { .. })` unless the condition holds.
///
/// The arguments after the condition are forwarded verbatim to `format!`;
/// the message is only built when the condition is false.
macro_rules! ensure {
    ($condition:expr, $($fmt_args:tt)*) => {{
        let holds: bool = $condition;
        if !holds {
            let msg = format!($($fmt_args)*);
            return Err(InvalidPredicateError::Other { msg: msg.into() });
        }
    }};
}
28
/// The comparison a [`TextPredicate`] performs on the text of its capture.
#[derive(Debug)]
pub(super) enum TextPredicateKind {
    /// The captured text is compared against a string literal of the query.
    EqString(QueryStr),
    /// The captured text is compared against the text of another capture.
    EqCapture(Capture),
    /// The captured text is matched against a compiled regex.
    MatchString(Regex),
    /// The captured text is compared against each of several string literals.
    AnyString(Box<[QueryStr]>),
}
36
/// A text-based predicate evaluated against the nodes matched for one capture.
#[derive(Debug)]
pub(crate) struct TextPredicate {
    /// The capture whose matched nodes are tested.
    capture: Capture,
    /// What the text of each node is compared against.
    kind: TextPredicateKind,
    /// Whether the result of each individual comparison is inverted.
    negated: bool,
    /// Whether all nodes of the capture must pass (`true`) or a single
    /// passing node is enough (`false`).
    match_all: bool,
}
44
45fn input_matches_str<I: Input>(str: &str, range: Range<u32>, input: &mut I) -> bool {
46    if str.len() != range.len() {
47        return false;
48    }
49    let mut str = str.as_bytes();
50    let cursor = input.cursor_at(range.start);
51    let range = range.start as usize..range.end as usize;
52    let start_in_chunk = range.start - cursor.offset();
53    if range.end - cursor.offset() <= cursor.chunk().len() {
54        // hotpath
55        return &cursor.chunk()[start_in_chunk..range.end - cursor.offset()] == str;
56    }
57    if cursor.chunk()[start_in_chunk..] != str[..cursor.chunk().len() - start_in_chunk] {
58        return false;
59    }
60    str = &str[..cursor.chunk().len() - start_in_chunk];
61    while cursor.advance() {
62        if str.len() <= cursor.chunk().len() {
63            return &cursor.chunk()[..range.end - cursor.offset()] == str;
64        }
65        if &str[..cursor.chunk().len()] != cursor.chunk() {
66            return false;
67        }
68        str = &str[cursor.chunk().len()..]
69    }
70    // buggy cursor/invalid range
71    false
72}
73
74impl TextPredicate {
75    /// handlers match_all and negated
76    fn satisfied_helper(&self, mut nodes: impl Iterator<Item = bool>) -> bool {
77        if self.match_all {
78            nodes.all(|matched| matched != self.negated)
79        } else {
80            nodes.any(|matched| matched != self.negated)
81        }
82    }
83
84    pub fn satisfied<I: Input>(
85        &self,
86        input: &mut I,
87        matched_nodes: &[MatchedNode],
88        query: &Query,
89    ) -> bool {
90        let mut capture_nodes = matched_nodes
91            .iter()
92            .filter(|matched_node| matched_node.capture == self.capture);
93        match self.kind {
94            TextPredicateKind::EqString(str) => self.satisfied_helper(capture_nodes.map(|node| {
95                let range = node.node.byte_range();
96                input_matches_str(query.get_string(str), range.clone(), input)
97            })),
98            TextPredicateKind::EqCapture(other_capture) => {
99                let mut other_nodes = matched_nodes
100                    .iter()
101                    .filter(|matched_node| matched_node.capture == other_capture);
102
103                let res = self.satisfied_helper(zip(&mut capture_nodes, &mut other_nodes).map(
104                    |(node1, node2)| {
105                        let range1 = node1.node.byte_range();
106                        let range2 = node2.node.byte_range();
107                        input.eq(range1, range2)
108                    },
109                ));
110                let consumed_all = capture_nodes.next().is_none() && other_nodes.next().is_none();
111                res && (!self.match_all || consumed_all)
112            }
113            TextPredicateKind::MatchString(ref regex) => {
114                self.satisfied_helper(capture_nodes.map(|node| {
115                    let range = node.node.byte_range();
116                    let mut input = regex_cursor::Input::new(input.cursor_at(range.start));
117                    input.slice(range.start as usize..range.end as usize);
118                    regex.is_match(input)
119                }))
120            }
121            TextPredicateKind::AnyString(ref strings) => {
122                let strings = strings.iter().map(|&str| query.get_string(str));
123                self.satisfied_helper(capture_nodes.map(|node| {
124                    let range = node.node.byte_range();
125                    strings
126                        .clone()
127                        .filter(|str| str.len() == range.len())
128                        .any(|str| input_matches_str(str, range.clone(), input))
129                }))
130            }
131        }
132    }
133}
134
135impl Query {
136    pub(super) fn parse_pattern_predicates(
137        &mut self,
138        pattern: Pattern,
139        mut custom_predicate: impl FnMut(Pattern, UserPredicate) -> Result<(), InvalidPredicateError>,
140    ) -> Result<PatternData, InvalidPredicateError> {
141        let text_predicate_start = self.text_predicates.len() as u32;
142
143        let predicate_steps = unsafe {
144            let mut len = 0u32;
145            let raw_predicates = ts_query_predicates_for_pattern(self.raw, pattern.0, &mut len);
146            (len != 0)
147                .then(|| slice::from_raw_parts(raw_predicates, len as usize))
148                .unwrap_or_default()
149        };
150        let predicates = predicate_steps
151            .split(|step| step.kind == PredicateStepKind::Done)
152            .filter(|predicate| !predicate.is_empty());
153
154        for predicate in predicates {
155            let predicate = unsafe { Predicate::new(self, predicate)? };
156
157            match predicate.name() {
158                "eq?" | "not-eq?" | "any-eq?" | "any-not-eq?" => {
159                    predicate.check_arg_count(2)?;
160                    let capture_idx = predicate.capture_arg(0)?;
161                    let arg2 = predicate.arg(1);
162
163                    let negated = matches!(predicate.name(), "not-eq?" | "not-any-eq?");
164                    let match_all = matches!(predicate.name(), "eq?" | "not-eq?");
165                    let kind = match arg2 {
166                        PredicateArg::Capture(capture) => TextPredicateKind::EqCapture(capture),
167                        PredicateArg::String(str) => TextPredicateKind::EqString(str),
168                    };
169                    self.text_predicates.push(TextPredicate {
170                        capture: capture_idx,
171                        kind,
172                        negated,
173                        match_all,
174                    });
175                }
176
177                "match?" | "not-match?" | "any-match?" | "any-not-match?" => {
178                    predicate.check_arg_count(2)?;
179                    let capture_idx = predicate.capture_arg(0)?;
180                    let regex = predicate.query_str_arg(1)?.get(self);
181
182                    let negated = matches!(predicate.name(), "not-match?" | "any-not-match?");
183                    let match_all = matches!(predicate.name(), "match?" | "not-match?");
184                    let regex = match Regex::builder().build(regex) {
185                        Ok(regex) => regex,
186                        Err(err) => bail!("invalid regex '{regex}', {err}"),
187                    };
188                    self.text_predicates.push(TextPredicate {
189                        capture: capture_idx,
190                        kind: TextPredicateKind::MatchString(regex),
191                        negated,
192                        match_all,
193                    });
194                }
195
196                "set!" => {
197                    let property = QueryProperty::parse(&predicate)?;
198                    custom_predicate(
199                        pattern,
200                        UserPredicate::SetProperty {
201                            key: property.key.get(self),
202                            val: property.val.map(|val| val.get(self)),
203                        },
204                    )?
205                }
206                "is-not?" | "is?" => {
207                    let property = QueryProperty::parse(&predicate)?;
208                    custom_predicate(
209                        pattern,
210                        UserPredicate::IsPropertySet {
211                            negate: predicate.name() == "is-not?",
212                            key: property.key.get(self),
213                            val: property.val.map(|val| val.get(self)),
214                        },
215                    )?
216                }
217
218                "any-of?" | "not-any-of?" => {
219                    predicate.check_min_arg_count(1)?;
220                    let capture = predicate.capture_arg(0)?;
221                    let negated = predicate.name() == "not-any-of?";
222                    let values: Result<_, InvalidPredicateError> = (1..predicate.num_args())
223                        .map(|i| predicate.query_str_arg(i))
224                        .collect();
225                    self.text_predicates.push(TextPredicate {
226                        capture,
227                        kind: TextPredicateKind::AnyString(values?),
228                        negated,
229                        match_all: false,
230                    });
231                }
232
233                // is and is-not are better handled as custom predicates since interpreting is context dependent
234                // "is?" => property_predicates.push((QueryProperty::parse(&predicate), false)),
235                // "is-not?" => property_predicates.push((QueryProperty::parse(&predicate), true)),
236                _ => custom_predicate(pattern, UserPredicate::Other(predicate))?,
237            }
238        }
239        Ok(PatternData {
240            text_predicates: text_predicate_start..self.text_predicates.len() as u32,
241        })
242    }
243}
244
/// A single argument of a query predicate.
pub enum PredicateArg {
    /// The argument refers to a capture of the query.
    Capture(Capture),
    /// The argument is a string literal of the query.
    String(QueryStr),
}
249
/// A predicate of a query pattern: a name followed by its arguments.
#[derive(Debug, Clone, Copy)]
pub struct Predicate<'a> {
    /// The name of the predicate, resolved to text via [`Predicate::name`].
    pub name: QueryStr,
    /// The raw steps following the name, one per argument.
    args: &'a [PredicateStep],
    /// The query used to resolve names and string literals.
    query: &'a Query,
}
256
257impl<'a> Predicate<'a> {
258    unsafe fn new(
259        query: &'a Query,
260        predicate: &'a [PredicateStep],
261    ) -> Result<Predicate<'a>, InvalidPredicateError> {
262        ensure!(
263            predicate[0].kind == PredicateStepKind::String,
264            "expected predicate to start with a function name. Got @{}.",
265            Capture(predicate[0].value_id).name(query)
266        );
267        let operator_name = QueryStr(predicate[0].value_id);
268        Ok(Predicate {
269            name: operator_name,
270            args: &predicate[1..],
271            query,
272        })
273    }
274
275    pub fn name(&self) -> &str {
276        self.name.get(self.query)
277    }
278
279    pub fn check_arg_count(&self, n: usize) -> Result<(), InvalidPredicateError> {
280        ensure!(
281            self.args.len() == n,
282            "expected {n} arguments for #{}, got {}",
283            self.name(),
284            self.args.len()
285        );
286        Ok(())
287    }
288
289    pub fn check_min_arg_count(&self, n: usize) -> Result<(), InvalidPredicateError> {
290        ensure!(
291            n <= self.args.len(),
292            "expected at least {n} arguments for #{}, got {}",
293            self.name(),
294            self.args.len()
295        );
296        Ok(())
297    }
298
299    pub fn check_max_arg_count(&self, n: usize) -> Result<(), InvalidPredicateError> {
300        ensure!(
301            self.args.len() <= n,
302            "expected at most {n} arguments for #{}, got {}",
303            self.name(),
304            self.args.len()
305        );
306        Ok(())
307    }
308
309    pub fn query_str_arg(&self, i: usize) -> Result<QueryStr, InvalidPredicateError> {
310        match self.arg(i) {
311            PredicateArg::String(str) => Ok(str),
312            PredicateArg::Capture(capture) => bail!(
313                "{i}. argument to #{} must be a literal, got capture @{:?}",
314                self.name(),
315                capture.name(self.query)
316            ),
317        }
318    }
319
320    pub fn str_arg(&self, i: usize) -> Result<&str, InvalidPredicateError> {
321        Ok(self.query_str_arg(i)?.get(self.query))
322    }
323
324    pub fn num_args(&self) -> usize {
325        self.args.len()
326    }
327
328    pub fn capture_arg(&self, i: usize) -> Result<Capture, InvalidPredicateError> {
329        match self.arg(i) {
330            PredicateArg::Capture(capture) => Ok(capture),
331            PredicateArg::String(str) => bail!(
332                "{i}. argument to #{} expected a capture, got literal {:?}",
333                self.name(),
334                str.get(self.query)
335            ),
336        }
337    }
338
339    pub fn arg(&self, i: usize) -> PredicateArg {
340        self.args[i].try_into().unwrap()
341    }
342
343    pub fn args(&self) -> impl Iterator<Item = PredicateArg> + '_ {
344        self.args.iter().map(|&arg| arg.try_into().unwrap())
345    }
346}
347
/// The error produced when a predicate of a query cannot be accepted.
#[derive(Debug)]
pub enum InvalidPredicateError {
    /// The property specified in `#set! <prop>` is not known.
    UnknownProperty {
        property: Box<str>,
    },
    /// Predicate is unknown/unsupported by this query.
    UnknownPredicate {
        name: Box<str>,
    },
    /// Any other problem, described by a free-form message.
    Other {
        msg: Box<str>,
    },
}
362
363impl InvalidPredicateError {
364    pub fn unknown(predicate: UserPredicate) -> Self {
365        match predicate {
366            UserPredicate::IsPropertySet { key, .. } => Self::UnknownProperty {
367                property: key.into(),
368            },
369            UserPredicate::SetProperty { key, .. } => Self::UnknownProperty {
370                property: key.into(),
371            },
372            UserPredicate::Other(predicate) => Self::UnknownPredicate {
373                name: predicate.name().into(),
374            },
375        }
376    }
377}
378
379impl From<String> for InvalidPredicateError {
380    fn from(value: String) -> Self {
381        InvalidPredicateError::Other {
382            msg: value.into_boxed_str(),
383        }
384    }
385}
386
387impl<'a> From<&'a str> for InvalidPredicateError {
388    fn from(value: &'a str) -> Self {
389        InvalidPredicateError::Other { msg: value.into() }
390    }
391}
392
393impl fmt::Display for InvalidPredicateError {
394    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
395        match self {
396            Self::UnknownProperty { property } => write!(f, "unknown property '{property}'"),
397            Self::UnknownPredicate { name } => write!(f, "unknown predicate #{name}"),
398            Self::Other { msg } => f.write_str(msg),
399        }
400    }
401}
402
// No methods are overridden: the trait's provided defaults are used.
impl Error for InvalidPredicateError {}
404
/// The kind of a [`PredicateStep`] as reported by the C library.
///
/// NOTE(review): the discriminants presumably mirror tree-sitter's
/// `TSQueryPredicateStepType` - confirm against the C header.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
// warns about never being constructed but it's constructed by C code
// and written into a mutable reference
#[allow(dead_code)]
enum PredicateStepKind {
    /// Sentinel that terminates an individual predicate.
    Done = 0,
    /// The step refers to a capture.
    Capture = 1,
    /// The step is a string literal.
    String = 2,
}
415
/// One raw step of a predicate, laid out for the C API (`repr(C)`).
#[repr(C)]
#[derive(Debug, Clone, Copy)]
struct PredicateStep {
    /// Whether this step is a capture, a string or the end-of-predicate sentinel.
    kind: PredicateStepKind,
    /// Id of the capture or string this step refers to, depending on `kind`.
    value_id: u32,
}
422
423impl TryFrom<PredicateStep> for PredicateArg {
424    type Error = ();
425
426    fn try_from(step: PredicateStep) -> Result<Self, Self::Error> {
427        match step.kind {
428            PredicateStepKind::String => Ok(PredicateArg::String(QueryStr(step.value_id))),
429            PredicateStepKind::Capture => Ok(PredicateArg::Capture(Capture(step.value_id))),
430            PredicateStepKind::Done => Err(()),
431        }
432    }
433}
434
extern "C" {
    /// Get all of the predicates for the given pattern in the query. The
    /// predicates are represented as a single array of steps. There are three
    /// types of steps in this array, which correspond to the three legal values
    /// for the `type` field:
    ///
    /// - `TSQueryPredicateStepTypeCapture` - Steps with this type represent names of captures.
    ///   Their `value_id` can be used with the `ts_query_capture_name_for_id` function to
    ///   obtain the name of the capture.
    /// - `TSQueryPredicateStepTypeString` - Steps with this type represent literal strings.
    ///   Their `value_id` can be used with the `ts_query_string_value_for_id` function to
    ///   obtain their string value.
    /// - `TSQueryPredicateStepTypeDone` - Steps with this type are *sentinels* that represent the
    ///   end of an individual predicate. If a pattern has two predicates, then there will be two
    ///   steps with this `type` in the array.
    ///
    /// The number of steps in the returned array is written to `step_count`.
    ///
    /// NOTE(review): the lifetime of the returned array is not visible here;
    /// presumably it is owned by `query` - confirm against the C API.
    fn ts_query_predicates_for_pattern(
        query: NonNull<QueryData>,
        pattern_index: u32,
        step_count: &mut u32,
    ) -> *const PredicateStep;

}