1use std::fmt::{self, Display};
2use std::ops::Range;
3use std::ptr::NonNull;
4use std::{slice, str};
5
6use crate::query::predicate::TextPredicate;
7pub use crate::query::predicate::{InvalidPredicateError, Predicate};
8use crate::Grammar;
9
10mod predicate;
11mod property;
12
/// A predicate handed to the `custom_predicate` callback of [`Query::new`].
///
/// The `Display` impl renders the predicate back into query syntax.
#[derive(Debug)]
pub enum UserPredicate<'a> {
    /// `(#is? key [val])`, or `(#is-not? key [val])` when `negate` is `true`.
    IsPropertySet {
        negate: bool,
        key: &'a str,
        val: Option<&'a str>,
    },
    /// `(#set! key [val])`.
    SetProperty {
        key: &'a str,
        val: Option<&'a str>,
    },
    /// Any other predicate. NOTE(review): presumably one this crate does not interpret
    /// itself — confirm against `predicate.rs`.
    Other(Predicate<'a>),
}
26
27impl Display for UserPredicate<'_> {
28 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29 match *self {
30 UserPredicate::IsPropertySet { negate, key, val } => {
31 let predicate = if negate { "is-not?" } else { "is?" };
32 let spacer = if val.is_some() { " " } else { "" };
33 write!(f, " (#{predicate} {key}{spacer}{})", val.unwrap_or(""))
34 }
35 UserPredicate::SetProperty { key, val } => {
36 let spacer = if val.is_some() { " " } else { "" };
37 write!(f, "(#set! {key}{spacer}{})", val.unwrap_or(""))
38 }
39 UserPredicate::Other(ref predicate) => {
40 write!(f, "#{}", predicate.name())
41 }
42 }
43 }
44}
45
/// Index of a pattern within a [`Query`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Pattern(pub(crate) u32);

impl Pattern {
    /// Sentinel value (`u32::MAX`); presumably never a valid pattern index.
    pub const SENTINEL: Pattern = Pattern(u32::MAX);

    /// Returns the pattern index as a `usize`, e.g. for indexing per-pattern slices.
    pub fn idx(&self) -> usize {
        let Pattern(index) = *self;
        index as usize
    }
}
56
/// Opaque, uninhabited stand-in for the C query object; only ever used behind a
/// `NonNull` pointer passed to the `ts_query_*` functions (presumably `TSQuery`).
pub enum QueryData {}

/// Per-pattern data collected while parsing a query's predicates.
#[derive(Debug)]
pub(super) struct PatternData {
    /// Range of indices into `Query::text_predicates` that belong to this pattern.
    text_predicates: Range<u32>,
}

/// A compiled tree-sitter query together with its parsed predicates.
#[derive(Debug)]
pub struct Query {
    /// The owned C query; freed by the `Drop` impl via `ts_query_delete`.
    pub(crate) raw: NonNull<QueryData>,
    /// Capture count, cached from `ts_query_capture_count` at construction.
    num_captures: u32,
    /// String-value count, cached from `ts_query_string_count` at construction.
    num_strings: u32,
    /// Text predicates of all patterns; each pattern's slice is given by
    /// `PatternData::text_predicates`.
    text_predicates: Vec<TextPredicate>,
    /// One entry per pattern, indexed by pattern index.
    patterns: Box<[PatternData]>,
}

// SAFETY: NOTE(review) — this relies on the C query only being read through `&Query`
// (the one mutating call in this file, `ts_query_disable_capture`, requires
// `&mut self`) and on tree-sitter not keeping thread-affine state in a query.
// Confirm against the tree-sitter version in use.
unsafe impl Send for Query {}
unsafe impl Sync for Query {}
75
76impl Query {
77 pub fn new(
84 grammar: Grammar,
85 source: &str,
86 mut custom_predicate: impl FnMut(Pattern, UserPredicate) -> Result<(), InvalidPredicateError>,
87 ) -> Result<Self, ParseError> {
88 assert!(
89 source.len() <= i32::MAX as usize,
90 "TreeSitter queries must be smaller then 2 GiB (is {})",
91 source.len() as f64 / 1024.0 / 1024.0 / 1024.0
92 );
93 let mut error_offset = 0u32;
94 let mut error_kind = RawQueryError::None;
95 let bytes = source.as_bytes();
96
97 let ptr = unsafe {
99 ts_query_new(
100 grammar,
101 bytes.as_ptr(),
102 bytes.len() as u32,
103 &mut error_offset,
104 &mut error_kind,
105 )
106 };
107
108 let Some(raw) = ptr else {
109 let offset = error_offset as usize;
110 let error_word = || {
111 source[offset..]
112 .chars()
113 .take_while(|&c| c.is_alphanumeric() || matches!(c, '_' | '-'))
114 .collect()
115 };
116 let err = match error_kind {
117 RawQueryError::NodeType => {
118 let node: String = error_word();
119 ParseError::InvalidNodeType {
120 location: ParserErrorLocation::new(source, offset, node.chars().count()),
121 node,
122 }
123 }
124 RawQueryError::Field => {
125 let field = error_word();
126 ParseError::InvalidFieldName {
127 location: ParserErrorLocation::new(source, offset, field.chars().count()),
128 field,
129 }
130 }
131 RawQueryError::Capture => {
132 let capture = error_word();
133 ParseError::InvalidCaptureName {
134 location: ParserErrorLocation::new(source, offset, capture.chars().count()),
135 capture,
136 }
137 }
138 RawQueryError::Syntax => {
139 ParseError::SyntaxError(ParserErrorLocation::new(source, offset, 0))
140 }
141 RawQueryError::Structure => {
142 ParseError::ImpossiblePattern(ParserErrorLocation::new(source, offset, 0))
143 }
144 RawQueryError::None => {
145 unreachable!("tree-sitter returned a null pointer but did not set an error")
146 }
147 RawQueryError::Language => unreachable!("should be handled at grammar load"),
148 };
149 return Err(err);
150 };
151
152 let num_captures = unsafe { ts_query_capture_count(raw) };
155 let num_strings = unsafe { ts_query_string_count(raw) };
156 let num_patterns = unsafe { ts_query_pattern_count(raw) };
157
158 let mut query = Query {
159 raw,
160 num_captures,
161 num_strings,
162 text_predicates: Vec::new(),
163 patterns: Box::default(),
164 };
165 let patterns: Result<_, ParseError> = (0..num_patterns)
166 .map(|pattern| {
167 query
168 .parse_pattern_predicates(Pattern(pattern), &mut custom_predicate)
169 .map_err(|err| {
170 let pattern_start =
171 unsafe { ts_query_start_byte_for_pattern(query.raw, pattern) as usize };
172 match err {
173 InvalidPredicateError::UnknownPredicate { name } => {
174 let offset = source[pattern_start..]
175 .find(&*name)
176 .expect("predicate name is a substring of the query text")
177 + pattern_start
178 - 1;
180 ParseError::InvalidPredicate {
181 message: format!("unknown predicate #{name}"),
182 location: ParserErrorLocation::new(
183 source,
184 offset,
185 name.chars().count() + 1,
187 ),
188 }
189 }
190 InvalidPredicateError::UnknownProperty { property } => {
191 let offset = source[pattern_start..]
194 .find(&*property)
195 .expect("property name is a substring of the query text")
196 + pattern_start;
197 ParseError::InvalidPredicate {
198 message: format!("unknown property '{property}'"),
199 location: ParserErrorLocation::new(
200 source,
201 offset,
202 property.chars().count(),
203 ),
204 }
205 }
206 InvalidPredicateError::Other { msg } => ParseError::InvalidPredicate {
207 message: msg.into(),
208 location: ParserErrorLocation::new(source, pattern_start, 0),
209 },
210 }
211 })
212 })
213 .collect();
214 query.patterns = patterns?;
215 Ok(query)
216 }
217
218 #[inline]
219 fn get_string(&self, str: QueryStr) -> &str {
220 let value_id = str.0;
221 assert!(value_id <= self.num_strings, "invalid value index");
223 unsafe {
224 let mut len = 0;
225 let ptr = ts_query_string_value_for_id(self.raw, value_id, &mut len);
226 let data = slice::from_raw_parts(ptr, len as usize);
227 str::from_utf8_unchecked(data)
231 }
232 }
233
234 #[inline]
235 pub fn capture_name(&self, capture_idx: Capture) -> &str {
236 let capture_idx = capture_idx.0;
237 assert!(capture_idx <= self.num_captures, "invalid capture index");
239 let mut length = 0;
240 unsafe {
241 let ptr = ts_query_capture_name_for_id(self.raw, capture_idx, &mut length);
242 let name = slice::from_raw_parts(ptr, length as usize);
243 str::from_utf8_unchecked(name)
247 }
248 }
249
250 #[inline]
251 pub fn captures(&self) -> impl ExactSizeIterator<Item = (Capture, &str)> {
252 (0..self.num_captures).map(|cap| (Capture(cap), self.capture_name(Capture(cap))))
253 }
254
255 #[inline]
256 pub fn num_captures(&self) -> u32 {
257 self.num_captures
258 }
259
260 #[inline]
261 pub fn get_capture(&self, capture_name: &str) -> Option<Capture> {
262 for capture in 0..self.num_captures {
263 if capture_name == self.capture_name(Capture(capture)) {
264 return Some(Capture(capture));
265 }
266 }
267 None
268 }
269
270 pub(crate) fn pattern_text_predicates(&self, pattern_idx: u16) -> &[TextPredicate] {
271 let range = self.patterns[pattern_idx as usize].text_predicates.clone();
272 &self.text_predicates[range.start as usize..range.end as usize]
273 }
274
275 #[doc(alias = "ts_query_start_byte_for_pattern")]
278 #[must_use]
279 pub fn start_byte_for_pattern(&self, pattern: Pattern) -> usize {
280 assert!(
281 pattern.0 < self.text_predicates.len() as u32,
282 "Pattern index is {pattern:?} but the pattern count is {}",
283 self.text_predicates.len(),
284 );
285 unsafe { ts_query_start_byte_for_pattern(self.raw, pattern.0) as usize }
286 }
287
288 #[must_use]
290 pub fn pattern_count(&self) -> usize {
291 unsafe { ts_query_pattern_count(self.raw) as usize }
292 }
293 #[must_use]
295 pub fn patterns(&self) -> impl ExactSizeIterator<Item = Pattern> {
296 (0..self.pattern_count() as u32).map(Pattern)
297 }
298
299 pub fn disable_capture(&mut self, name: &str) {
305 let bytes = name.as_bytes();
306 unsafe {
307 ts_query_disable_capture(self.raw, bytes.as_ptr(), bytes.len() as u32);
308 }
309 }
310}
311
312impl Drop for Query {
313 fn drop(&mut self) {
314 unsafe { ts_query_delete(self.raw) }
315 }
316}
317
318#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
319#[repr(transparent)]
320pub struct Capture(u32);
321
322impl Capture {
323 pub fn name(self, query: &Query) -> &str {
324 query.capture_name(self)
325 }
326 pub fn idx(self) -> usize {
327 self.0 as usize
328 }
329}
330
331#[derive(Clone, Copy, Debug)]
333pub struct QueryStr(u32);
334
335impl QueryStr {
336 pub fn get(self, query: &Query) -> &str {
337 query.get_string(self)
338 }
339}
340
/// A location in a query source attached to a [`ParseError`], with enough surrounding
/// context to render a compiler-style snippet through its [`Display`] impl.
#[derive(Debug, PartialEq, Eq)]
pub struct ParserErrorLocation {
    /// 0-based index of the line containing the error.
    pub line: u32,
    /// 0-based column of the error within its line, counted in `char`s.
    pub column: u32,
    /// Number of characters to underline, starting at `column`.
    pub len: u32,
    /// Text of the error line, without its `\n` / `\r\n` terminator.
    line_content: String,
    /// The preceding line, if there is one and it is not empty.
    line_before: Option<String>,
    /// The following line, if there is one and it is not empty.
    line_after: Option<String>,
}

impl ParserErrorLocation {
    /// Describes the error at byte offset `start` of `source`, underlining `len` chars.
    ///
    /// If `start` lies past the end of `source`, an empty location at line 0, column 0
    /// is returned.
    ///
    /// # Panics
    ///
    /// Panics if `start` falls inside a multi-byte UTF-8 character.
    pub fn new(source: &str, start: usize, len: usize) -> ParserErrorLocation {
        let mut line_start = 0;
        for (line_no, raw_line) in source.split('\n').enumerate() {
            // `raw_line` may still end in `\r`; `line_end` is the offset of its `\n`.
            let line_end = line_start + raw_line.len();
            if (line_start..=line_end).contains(&start) {
                let line_before = source[..line_start]
                    .lines()
                    .next_back()
                    .filter(|s| !s.is_empty())
                    .map(ToOwned::to_owned);
                let line_after = source
                    .get(line_end + 1..)
                    .and_then(|rest| rest.lines().next())
                    .filter(|s| !s.is_empty())
                    .map(ToOwned::to_owned);
                return ParserErrorLocation {
                    line: line_no as u32,
                    column: source[line_start..start].chars().count() as u32,
                    len: len as u32,
                    line_content: raw_line.strip_suffix('\r').unwrap_or(raw_line).to_owned(),
                    line_before,
                    line_after,
                };
            }
            // Skip past this line and the `\n` that terminated it.
            line_start = line_end + 1;
        }

        ParserErrorLocation {
            line: 0,
            column: 0,
            len: len as u32,
            line_content: String::new(),
            line_before: None,
            line_after: None,
        }
    }
}

impl Display for ParserErrorLocation {
    /// Renders a `--> line:column` header, the surrounding lines with 1-based line
    /// numbers, and a row of `^` under the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, " --> {}:{}", self.line + 1, self.column + 1)?;

        // The gutter must fit the largest line number shown. Every number is
        // right-aligned within it so the `|` separators line up even when the digit
        // count changes (e.g. from line 9 to line 10).
        let line_no = self.line + 1;
        let max_line_no = if self.line_after.is_some() {
            line_no + 1
        } else {
            line_no
        };
        let gutter = max_line_no.to_string().len();
        let prefix = format!(" {:width$} |", "", width = gutter);

        writeln!(f, "{prefix}")?;
        if let Some(before) = &self.line_before {
            writeln!(f, " {:>width$} | {}", self.line, before, width = gutter)?;
        }
        writeln!(f, " {:>width$} | {}", line_no, self.line_content, width = gutter)?;
        writeln!(
            f,
            "{prefix}{:width$} {:^<len$}",
            "",
            "^",
            width = self.column as usize,
            len = self.len as usize
        )?;
        if let Some(after) = &self.line_after {
            writeln!(f, " {:>width$} | {}", line_no + 1, after, width = gutter)?;
        }
        writeln!(f, "{prefix}")
    }
}
431
/// Error returned when compiling a query fails.
///
/// Every variant except `UnexpectedEof` carries a [`ParserErrorLocation`], whose
/// rendering is appended to the error message.
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
pub enum ParseError {
    /// Not constructed in this file. NOTE(review): confirm where it is produced.
    #[error("unexpected EOF")]
    UnexpectedEof,
    /// tree-sitter reported a syntax error at this location.
    #[error("invalid query syntax\n{0}")]
    SyntaxError(ParserErrorLocation),
    /// tree-sitter rejected the node type `node`.
    #[error("invalid node type {node:?}\n{location}")]
    InvalidNodeType {
        node: String,
        location: ParserErrorLocation,
    },
    /// tree-sitter rejected the field name `field`.
    #[error("invalid field name {field:?}\n{location}")]
    InvalidFieldName {
        field: String,
        location: ParserErrorLocation,
    },
    /// tree-sitter rejected the capture name `capture`.
    #[error("invalid capture name {capture:?}\n{location}")]
    InvalidCaptureName {
        capture: String,
        location: ParserErrorLocation,
    },
    /// A predicate could not be parsed; `message` describes the problem.
    #[error("{message}\n{location}")]
    InvalidPredicate {
        message: String,
        location: ParserErrorLocation,
    },
    /// tree-sitter reported a structure error (an impossible pattern) at this location.
    #[error("impossible pattern\n{0}")]
    ImpossiblePattern(ParserErrorLocation),
}
461
/// Error kind written by `ts_query_new` through its `error_type` out-parameter.
///
/// NOTE(review): the discriminants must match tree-sitter's C `TSQueryError` exactly.
/// If the C side ever writes a value outside this set, the Rust code has undefined
/// behaviour — re-check when upgrading tree-sitter.
#[repr(C)]
// Most variants are only ever produced by the C side, so rustc sees them as never
// constructed.
#[allow(dead_code)]
enum RawQueryError {
    None = 0,
    Syntax = 1,
    NodeType = 2,
    Field = 3,
    Capture = 4,
    Structure = 5,
    Language = 6,
}
475
extern "C" {
    /// Compiles `source_len` bytes at `source` into a query for `grammar`.
    ///
    /// Returns `None` (a null pointer) on failure. In that case `error_offset` and
    /// `error_type` describe the error.
    ///
    /// NOTE(review): `Grammar` is passed by value across FFI. This assumes it is
    /// ABI-compatible with `const TSLanguage *` — confirm.
    fn ts_query_new(
        grammar: Grammar,
        source: *const u8,
        source_len: u32,
        error_offset: &mut u32,
        error_type: &mut RawQueryError,
    ) -> Option<NonNull<QueryData>>;

    /// Frees a query returned by `ts_query_new`.
    fn ts_query_delete(query: NonNull<QueryData>);

    /// Number of patterns in the query.
    fn ts_query_pattern_count(query: NonNull<QueryData>) -> u32;
    /// Number of captures in the query.
    fn ts_query_capture_count(query: NonNull<QueryData>) -> u32;
    /// Number of string values in the query.
    fn ts_query_string_count(query: NonNull<QueryData>) -> u32;

    /// Byte offset in the query source at which pattern `pattern_index` starts.
    fn ts_query_start_byte_for_pattern(query: NonNull<QueryData>, pattern_index: u32) -> u32;

    /// Returns a pointer to the name of capture `index` and writes its byte length to
    /// `length`.
    fn ts_query_capture_name_for_id(
        query: NonNull<QueryData>,
        index: u32,
        length: &mut u32,
    ) -> *const u8;

    /// Returns a pointer to the string value with id `index` and writes its byte length
    /// to `length`.
    fn ts_query_string_value_for_id(
        self_: NonNull<QueryData>,
        index: u32,
        length: &mut u32,
    ) -> *const u8;

    /// Disables the capture whose name is the `length` bytes at `name`.
    fn ts_query_disable_capture(self_: NonNull<QueryData>, name: *const u8, length: u32);
}