1use std::error::Error;
2use std::iter::zip;
3use std::ops::Range;
4use std::ptr::NonNull;
5use std::{fmt, slice};
6
7use crate::query::property::QueryProperty;
8use crate::query::{Capture, Pattern, PatternData, Query, QueryData, QueryStr, UserPredicate};
9use crate::query_cursor::MatchedNode;
10use crate::Input;
11
12use regex_cursor::engines::meta::Regex;
13use regex_cursor::Cursor;
14
/// Returns early from the enclosing function with an
/// `InvalidPredicateError::Other` whose message is produced from the given
/// `format!`-style arguments.
macro_rules! bail {
    ($($fmt:tt)*) => {{
        return Err(InvalidPredicateError::Other {
            msg: format!($($fmt)*).into_boxed_str(),
        })
    }};
}
20
/// Checks a condition and, when it does not hold, returns early from the
/// enclosing function with an `InvalidPredicateError::Other` built from the
/// remaining `format!`-style arguments.
///
/// The message arguments are only evaluated when the condition is false.
macro_rules! ensure {
    ($condition:expr, $($fmt:tt)*) => {{
        let holds: bool = $condition;
        if !holds {
            return Err(InvalidPredicateError::Other {
                msg: format!($($fmt)*).into_boxed_str(),
            });
        }
    }};
}
28
29#[derive(Debug)]
30pub(super) enum TextPredicateKind {
31 EqString(QueryStr),
32 EqCapture(Capture),
33 MatchString(Regex),
34 AnyString(Box<[QueryStr]>),
35}
36
37#[derive(Debug)]
38pub(crate) struct TextPredicate {
39 capture: Capture,
40 kind: TextPredicateKind,
41 negated: bool,
42 match_all: bool,
43}
44
45fn input_matches_str<I: Input>(str: &str, range: Range<u32>, input: &mut I) -> bool {
46 if str.len() != range.len() {
47 return false;
48 }
49 let mut str = str.as_bytes();
50 let cursor = input.cursor_at(range.start);
51 let range = range.start as usize..range.end as usize;
52 let start_in_chunk = range.start - cursor.offset();
53 if range.end - cursor.offset() <= cursor.chunk().len() {
54 return &cursor.chunk()[start_in_chunk..range.end - cursor.offset()] == str;
56 }
57 if cursor.chunk()[start_in_chunk..] != str[..cursor.chunk().len() - start_in_chunk] {
58 return false;
59 }
60 str = &str[..cursor.chunk().len() - start_in_chunk];
61 while cursor.advance() {
62 if str.len() <= cursor.chunk().len() {
63 return &cursor.chunk()[..range.end - cursor.offset()] == str;
64 }
65 if &str[..cursor.chunk().len()] != cursor.chunk() {
66 return false;
67 }
68 str = &str[cursor.chunk().len()..]
69 }
70 false
72}
73
74impl TextPredicate {
75 fn satisfied_helper(&self, mut nodes: impl Iterator<Item = bool>) -> bool {
77 if self.match_all {
78 nodes.all(|matched| matched != self.negated)
79 } else {
80 nodes.any(|matched| matched != self.negated)
81 }
82 }
83
84 pub fn satisfied<I: Input>(
85 &self,
86 input: &mut I,
87 matched_nodes: &[MatchedNode],
88 query: &Query,
89 ) -> bool {
90 let mut capture_nodes = matched_nodes
91 .iter()
92 .filter(|matched_node| matched_node.capture == self.capture);
93 match self.kind {
94 TextPredicateKind::EqString(str) => self.satisfied_helper(capture_nodes.map(|node| {
95 let range = node.node.byte_range();
96 input_matches_str(query.get_string(str), range.clone(), input)
97 })),
98 TextPredicateKind::EqCapture(other_capture) => {
99 let mut other_nodes = matched_nodes
100 .iter()
101 .filter(|matched_node| matched_node.capture == other_capture);
102
103 let res = self.satisfied_helper(zip(&mut capture_nodes, &mut other_nodes).map(
104 |(node1, node2)| {
105 let range1 = node1.node.byte_range();
106 let range2 = node2.node.byte_range();
107 input.eq(range1, range2)
108 },
109 ));
110 let consumed_all = capture_nodes.next().is_none() && other_nodes.next().is_none();
111 res && (!self.match_all || consumed_all)
112 }
113 TextPredicateKind::MatchString(ref regex) => {
114 self.satisfied_helper(capture_nodes.map(|node| {
115 let range = node.node.byte_range();
116 let mut input = regex_cursor::Input::new(input.cursor_at(range.start));
117 input.slice(range.start as usize..range.end as usize);
118 regex.is_match(input)
119 }))
120 }
121 TextPredicateKind::AnyString(ref strings) => {
122 let strings = strings.iter().map(|&str| query.get_string(str));
123 self.satisfied_helper(capture_nodes.map(|node| {
124 let range = node.node.byte_range();
125 strings
126 .clone()
127 .filter(|str| str.len() == range.len())
128 .any(|str| input_matches_str(str, range.clone(), input))
129 }))
130 }
131 }
132 }
133}
134
135impl Query {
136 pub(super) fn parse_pattern_predicates(
137 &mut self,
138 pattern: Pattern,
139 mut custom_predicate: impl FnMut(Pattern, UserPredicate) -> Result<(), InvalidPredicateError>,
140 ) -> Result<PatternData, InvalidPredicateError> {
141 let text_predicate_start = self.text_predicates.len() as u32;
142
143 let predicate_steps = unsafe {
144 let mut len = 0u32;
145 let raw_predicates = ts_query_predicates_for_pattern(self.raw, pattern.0, &mut len);
146 (len != 0)
147 .then(|| slice::from_raw_parts(raw_predicates, len as usize))
148 .unwrap_or_default()
149 };
150 let predicates = predicate_steps
151 .split(|step| step.kind == PredicateStepKind::Done)
152 .filter(|predicate| !predicate.is_empty());
153
154 for predicate in predicates {
155 let predicate = unsafe { Predicate::new(self, predicate)? };
156
157 match predicate.name() {
158 "eq?" | "not-eq?" | "any-eq?" | "any-not-eq?" => {
159 predicate.check_arg_count(2)?;
160 let capture_idx = predicate.capture_arg(0)?;
161 let arg2 = predicate.arg(1);
162
163 let negated = matches!(predicate.name(), "not-eq?" | "not-any-eq?");
164 let match_all = matches!(predicate.name(), "eq?" | "not-eq?");
165 let kind = match arg2 {
166 PredicateArg::Capture(capture) => TextPredicateKind::EqCapture(capture),
167 PredicateArg::String(str) => TextPredicateKind::EqString(str),
168 };
169 self.text_predicates.push(TextPredicate {
170 capture: capture_idx,
171 kind,
172 negated,
173 match_all,
174 });
175 }
176
177 "match?" | "not-match?" | "any-match?" | "any-not-match?" => {
178 predicate.check_arg_count(2)?;
179 let capture_idx = predicate.capture_arg(0)?;
180 let regex = predicate.query_str_arg(1)?.get(self);
181
182 let negated = matches!(predicate.name(), "not-match?" | "any-not-match?");
183 let match_all = matches!(predicate.name(), "match?" | "not-match?");
184 let regex = match Regex::builder().build(regex) {
185 Ok(regex) => regex,
186 Err(err) => bail!("invalid regex '{regex}', {err}"),
187 };
188 self.text_predicates.push(TextPredicate {
189 capture: capture_idx,
190 kind: TextPredicateKind::MatchString(regex),
191 negated,
192 match_all,
193 });
194 }
195
196 "set!" => {
197 let property = QueryProperty::parse(&predicate)?;
198 custom_predicate(
199 pattern,
200 UserPredicate::SetProperty {
201 key: property.key.get(self),
202 val: property.val.map(|val| val.get(self)),
203 },
204 )?
205 }
206 "is-not?" | "is?" => {
207 let property = QueryProperty::parse(&predicate)?;
208 custom_predicate(
209 pattern,
210 UserPredicate::IsPropertySet {
211 negate: predicate.name() == "is-not?",
212 key: property.key.get(self),
213 val: property.val.map(|val| val.get(self)),
214 },
215 )?
216 }
217
218 "any-of?" | "not-any-of?" => {
219 predicate.check_min_arg_count(1)?;
220 let capture = predicate.capture_arg(0)?;
221 let negated = predicate.name() == "not-any-of?";
222 let values: Result<_, InvalidPredicateError> = (1..predicate.num_args())
223 .map(|i| predicate.query_str_arg(i))
224 .collect();
225 self.text_predicates.push(TextPredicate {
226 capture,
227 kind: TextPredicateKind::AnyString(values?),
228 negated,
229 match_all: false,
230 });
231 }
232
233 _ => custom_predicate(pattern, UserPredicate::Other(predicate))?,
237 }
238 }
239 Ok(PatternData {
240 text_predicates: text_predicate_start..self.text_predicates.len() as u32,
241 })
242 }
243}
244
245pub enum PredicateArg {
246 Capture(Capture),
247 String(QueryStr),
248}
249
250#[derive(Debug, Clone, Copy)]
251pub struct Predicate<'a> {
252 pub name: QueryStr,
253 args: &'a [PredicateStep],
254 query: &'a Query,
255}
256
257impl<'a> Predicate<'a> {
258 unsafe fn new(
259 query: &'a Query,
260 predicate: &'a [PredicateStep],
261 ) -> Result<Predicate<'a>, InvalidPredicateError> {
262 ensure!(
263 predicate[0].kind == PredicateStepKind::String,
264 "expected predicate to start with a function name. Got @{}.",
265 Capture(predicate[0].value_id).name(query)
266 );
267 let operator_name = QueryStr(predicate[0].value_id);
268 Ok(Predicate {
269 name: operator_name,
270 args: &predicate[1..],
271 query,
272 })
273 }
274
275 pub fn name(&self) -> &str {
276 self.name.get(self.query)
277 }
278
279 pub fn check_arg_count(&self, n: usize) -> Result<(), InvalidPredicateError> {
280 ensure!(
281 self.args.len() == n,
282 "expected {n} arguments for #{}, got {}",
283 self.name(),
284 self.args.len()
285 );
286 Ok(())
287 }
288
289 pub fn check_min_arg_count(&self, n: usize) -> Result<(), InvalidPredicateError> {
290 ensure!(
291 n <= self.args.len(),
292 "expected at least {n} arguments for #{}, got {}",
293 self.name(),
294 self.args.len()
295 );
296 Ok(())
297 }
298
299 pub fn check_max_arg_count(&self, n: usize) -> Result<(), InvalidPredicateError> {
300 ensure!(
301 self.args.len() <= n,
302 "expected at most {n} arguments for #{}, got {}",
303 self.name(),
304 self.args.len()
305 );
306 Ok(())
307 }
308
309 pub fn query_str_arg(&self, i: usize) -> Result<QueryStr, InvalidPredicateError> {
310 match self.arg(i) {
311 PredicateArg::String(str) => Ok(str),
312 PredicateArg::Capture(capture) => bail!(
313 "{i}. argument to #{} must be a literal, got capture @{:?}",
314 self.name(),
315 capture.name(self.query)
316 ),
317 }
318 }
319
320 pub fn str_arg(&self, i: usize) -> Result<&str, InvalidPredicateError> {
321 Ok(self.query_str_arg(i)?.get(self.query))
322 }
323
324 pub fn num_args(&self) -> usize {
325 self.args.len()
326 }
327
328 pub fn capture_arg(&self, i: usize) -> Result<Capture, InvalidPredicateError> {
329 match self.arg(i) {
330 PredicateArg::Capture(capture) => Ok(capture),
331 PredicateArg::String(str) => bail!(
332 "{i}. argument to #{} expected a capture, got literal {:?}",
333 self.name(),
334 str.get(self.query)
335 ),
336 }
337 }
338
339 pub fn arg(&self, i: usize) -> PredicateArg {
340 self.args[i].try_into().unwrap()
341 }
342
343 pub fn args(&self) -> impl Iterator<Item = PredicateArg> + '_ {
344 self.args.iter().map(|&arg| arg.try_into().unwrap())
345 }
346}
347
/// Error produced while parsing the predicates of a query pattern.
#[derive(Debug)]
pub enum InvalidPredicateError {
    /// A property predicate used a key that is not recognized.
    UnknownProperty { property: Box<str> },
    /// A predicate with a name that is not recognized.
    UnknownPredicate { name: Box<str> },
    /// Any other problem, described by a free-form message.
    Other { msg: Box<str> },
}
362
363impl InvalidPredicateError {
364 pub fn unknown(predicate: UserPredicate) -> Self {
365 match predicate {
366 UserPredicate::IsPropertySet { key, .. } => Self::UnknownProperty {
367 property: key.into(),
368 },
369 UserPredicate::SetProperty { key, .. } => Self::UnknownProperty {
370 property: key.into(),
371 },
372 UserPredicate::Other(predicate) => Self::UnknownPredicate {
373 name: predicate.name().into(),
374 },
375 }
376 }
377}
378
379impl From<String> for InvalidPredicateError {
380 fn from(value: String) -> Self {
381 InvalidPredicateError::Other {
382 msg: value.into_boxed_str(),
383 }
384 }
385}
386
387impl<'a> From<&'a str> for InvalidPredicateError {
388 fn from(value: &'a str) -> Self {
389 InvalidPredicateError::Other { msg: value.into() }
390 }
391}
392
393impl fmt::Display for InvalidPredicateError {
394 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
395 match self {
396 Self::UnknownProperty { property } => write!(f, "unknown property '{property}'"),
397 Self::UnknownPredicate { name } => write!(f, "unknown predicate #{name}"),
398 Self::Other { msg } => f.write_str(msg),
399 }
400 }
401}
402
403impl Error for InvalidPredicateError {}
404
/// Kind tag of a [`PredicateStep`], laid out as a C enum.
///
/// NOTE(review): presumably mirrors the step type enum of the tree-sitter C
/// API — confirm the discriminants against the header.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
// No variant is constructed in this file; values are read from the step
// slice handed back by `ts_query_predicates_for_pattern`.
#[allow(dead_code)]
enum PredicateStepKind {
    /// Terminates the steps of one predicate.
    Done = 0,
    /// The step refers to a capture.
    Capture = 1,
    /// The step refers to a string.
    String = 2,
}
415
416#[repr(C)]
417#[derive(Debug, Clone, Copy)]
418struct PredicateStep {
419 kind: PredicateStepKind,
420 value_id: u32,
421}
422
423impl TryFrom<PredicateStep> for PredicateArg {
424 type Error = ();
425
426 fn try_from(step: PredicateStep) -> Result<Self, Self::Error> {
427 match step.kind {
428 PredicateStepKind::String => Ok(PredicateArg::String(QueryStr(step.value_id))),
429 PredicateStepKind::Capture => Ok(PredicateArg::Capture(Capture(step.value_id))),
430 PredicateStepKind::Done => Err(()),
431 }
432 }
433}
434
435extern "C" {
436 fn ts_query_predicates_for_pattern(
451 query: NonNull<QueryData>,
452 pattern_index: u32,
453 step_count: &mut u32,
454 ) -> *const PredicateStep;
455
456}