1use core::slice;
2use std::cell::UnsafeCell;
3use std::marker::PhantomData;
4use std::mem;
5use std::ops::Range;
6use std::ptr::{self, NonNull};
7
8use crate::node::NodeRaw;
9use crate::query::{Capture, Pattern, Query, QueryData};
10use crate::{Input, IntoInput, Node, Tree};
11
/// Opaque stand-in for the cursor object managed by the `ts_query_cursor_*` functions.
///
/// Values of this type are never created from Rust; it only appears behind raw
/// pointers, `NonNull` and references. It is a zero-sized `#[repr(C)]` struct
/// rather than an empty enum because an empty enum is uninhabited: a
/// `&mut QueryCursorData` could then never point to a valid value, and the
/// compiler is allowed to treat code that holds such a reference as unreachable.
#[repr(C)]
struct QueryCursorData {
    /// Zero-length payload: keeps the type inhabited while staying zero-sized.
    _data: [u8; 0],
}
13
thread_local! {
    /// Per-thread pool of idle query cursors.
    ///
    /// Dropping a `QueryCursor` pushes its underlying cursor into this list and
    /// `InactiveQueryCursor::new` pops from it before allocating a fresh one.
    /// The vector starts with capacity for 8 cursors. It lives in an
    /// `UnsafeCell`, and all access goes through `with_cache`.
    static CURSOR_CACHE: UnsafeCell<Vec<InactiveQueryCursor>> = UnsafeCell::new(Vec::with_capacity(8));
}
17
18unsafe fn with_cache<T>(f: impl FnOnce(&mut Vec<InactiveQueryCursor>) -> T) -> T {
20 CURSOR_CACHE.with(|cache| f(&mut *cache.get()))
21}
22
/// A query cursor that is executing `query`; matches are pulled from it with
/// `next_match` or `next_matched_node`.
///
/// When dropped, the underlying cursor is returned to the thread-local cache
/// instead of being freed.
pub struct QueryCursor<'a, 'tree, I: Input> {
    /// The query being executed; also consulted for the text predicates of each pattern.
    query: &'a Query,
    /// Raw cursor handle handed to the `ts_query_cursor_*` functions.
    ptr: NonNull<QueryCursorData>,
    /// Carries the `'tree` lifetime without storing a tree.
    tree: PhantomData<&'tree Tree>,
    /// Input passed to text predicates when they are evaluated.
    input: I,
}
29
30impl<'tree, I: Input> QueryCursor<'_, 'tree, I> {
31 pub fn next_match(&mut self) -> Option<QueryMatch<'_, 'tree>> {
32 let mut query_match = TSQueryMatch {
33 id: 0,
34 pattern_index: 0,
35 capture_count: 0,
36 captures: ptr::null(),
37 };
38 loop {
39 let success =
40 unsafe { ts_query_cursor_next_match(self.ptr.as_ptr(), &mut query_match) };
41 if !success {
42 return None;
43 }
44 let matched_nodes = unsafe {
45 slice::from_raw_parts(
46 query_match.captures.cast(),
47 query_match.capture_count as usize,
48 )
49 };
50 let satisfies_predicates = self
51 .query
52 .pattern_text_predicates(query_match.pattern_index)
53 .iter()
54 .all(|predicate| predicate.satisfied(&mut self.input, matched_nodes, self.query));
55 if satisfies_predicates {
56 let res = QueryMatch {
57 id: query_match.id,
58 pattern: Pattern(query_match.pattern_index as u32),
59 matched_nodes,
60 query_cursor: unsafe { self.ptr.as_mut() },
61 _tree: PhantomData,
62 };
63 return Some(res);
64 }
65 }
66 }
67
68 pub fn next_matched_node(&mut self) -> Option<(QueryMatch<'_, 'tree>, MatchedNodeIdx)> {
69 let mut query_match = TSQueryMatch {
70 id: 0,
71 pattern_index: 0,
72 capture_count: 0,
73 captures: ptr::null(),
74 };
75 let mut capture_idx = 0;
76 loop {
77 let success = unsafe {
78 ts_query_cursor_next_capture(self.ptr.as_ptr(), &mut query_match, &mut capture_idx)
79 };
80 if !success {
81 return None;
82 }
83 let matched_nodes = unsafe {
84 slice::from_raw_parts(
85 query_match.captures.cast(),
86 query_match.capture_count as usize,
87 )
88 };
89 let satisfies_predicates = self
90 .query
91 .pattern_text_predicates(query_match.pattern_index)
92 .iter()
93 .all(|predicate| predicate.satisfied(&mut self.input, matched_nodes, self.query));
94 if satisfies_predicates {
95 let res = QueryMatch {
96 id: query_match.id,
97 pattern: Pattern(query_match.pattern_index as u32),
98 matched_nodes,
99 query_cursor: unsafe { self.ptr.as_mut() },
100 _tree: PhantomData,
101 };
102 return Some((res, capture_idx));
103 } else {
104 unsafe {
105 ts_query_cursor_remove_match(self.ptr.as_ptr(), query_match.id);
106 }
107 }
108 }
109 }
110
111 pub fn set_byte_range(&mut self, range: Range<u32>) {
112 unsafe {
113 ts_query_cursor_set_byte_range(self.ptr.as_ptr(), range.start, range.end);
114 }
115 }
116
117 pub fn reuse(self) -> InactiveQueryCursor {
118 let res = InactiveQueryCursor { ptr: self.ptr };
119 mem::forget(self);
120 res
121 }
122}
123
124impl<I: Input> Drop for QueryCursor<'_, '_, I> {
125 fn drop(&mut self) {
126 unsafe { with_cache(|cache| cache.push(InactiveQueryCursor { ptr: self.ptr })) }
127 }
128}
129
/// A query cursor that is not currently executing a query.
///
/// It owns the underlying cursor: dropping it calls `ts_query_cursor_delete`.
/// Use `execute_query` to turn it into a running `QueryCursor`.
pub struct InactiveQueryCursor {
    /// Raw cursor handle handed to the `ts_query_cursor_*` functions.
    ptr: NonNull<QueryCursorData>,
}
134
135impl InactiveQueryCursor {
136 #[must_use]
137 pub fn new(range: Range<u32>, limit: u32) -> Self {
138 let mut this = unsafe {
139 with_cache(|cache| {
140 cache.pop().unwrap_or_else(|| InactiveQueryCursor {
141 ptr: NonNull::new_unchecked(ts_query_cursor_new()),
142 })
143 })
144 };
145 this.set_byte_range(range);
146 this.set_match_limit(limit);
147 this
148 }
149
150 #[doc(alias = "ts_query_cursor_match_limit")]
152 #[must_use]
153 pub fn match_limit(&self) -> u32 {
154 unsafe { ts_query_cursor_match_limit(self.ptr.as_ptr()) }
155 }
156
157 #[doc(alias = "ts_query_cursor_set_match_limit")]
160 pub fn set_match_limit(&mut self, limit: u32) {
161 unsafe {
162 ts_query_cursor_set_match_limit(self.ptr.as_ptr(), limit);
163 }
164 }
165
166 #[doc(alias = "ts_query_cursor_did_exceed_match_limit")]
169 #[must_use]
170 pub fn did_exceed_match_limit(&self) -> bool {
171 unsafe { ts_query_cursor_did_exceed_match_limit(self.ptr.as_ptr()) }
172 }
173
174 pub fn set_byte_range(&mut self, range: Range<u32>) {
175 unsafe {
176 ts_query_cursor_set_byte_range(self.ptr.as_ptr(), range.start, range.end);
177 }
178 }
179
180 pub fn execute_query<'a, 'tree, I: IntoInput>(
181 self,
182 query: &'a Query,
183 node: &Node<'tree>,
184 input: I,
185 ) -> QueryCursor<'a, 'tree, I::Input> {
186 let ptr = self.ptr;
187 unsafe { ts_query_cursor_exec(ptr.as_ptr(), query.raw.as_ref(), node.as_raw()) };
188 mem::forget(self);
189 QueryCursor {
190 query,
191 ptr,
192 tree: PhantomData,
193 input: input.into_input(),
194 }
195 }
196}
197
198impl Default for InactiveQueryCursor {
199 fn default() -> Self {
200 Self::new(0..u32::MAX, u32::MAX)
201 }
202}
203
204impl Drop for InactiveQueryCursor {
205 fn drop(&mut self) {
206 unsafe { ts_query_cursor_delete(self.ptr.as_ptr()) }
207 }
208}
209
/// Index type returned alongside a match by `QueryCursor::next_matched_node`
/// and accepted by `QueryMatch::matched_node`.
pub type MatchedNodeIdx = u32;
211
/// A node captured by a query match, paired with the capture it was bound to.
///
/// The cursor code casts a `*const TSQueryCapture` to a pointer to this type,
/// so the `#[repr(C)]` layout here must stay compatible with `TSQueryCapture`
/// (`node` first, then a `u32`-sized capture).
/// NOTE(review): presumably `Node` is layout-compatible with `NodeRaw` and
/// `Capture` with `u32` — confirm in their defining modules.
#[repr(C)]
#[derive(Debug, Clone)]
pub struct MatchedNode<'tree> {
    /// The captured syntax node.
    pub node: Node<'tree>,
    /// The capture this node was matched under.
    pub capture: Capture,
}
218
/// A single match produced by a `QueryCursor`.
///
/// It mutably borrows the cursor for `'cursor`, so it must be released before
/// the cursor can be advanced again.
pub struct QueryMatch<'cursor, 'tree> {
    /// Match id as reported by the cursor; used by `remove`.
    id: u32,
    /// The pattern of the query that produced this match.
    pattern: Pattern,
    /// The nodes captured by this match, viewed directly from the cursor's capture storage.
    matched_nodes: &'cursor [MatchedNode<'tree>],
    /// The cursor that produced this match; needed to remove the match from it.
    query_cursor: &'cursor mut QueryCursorData,
    /// Carries the `'tree` lifetime without storing a tree.
    _tree: PhantomData<&'tree super::Tree>,
}
226
227impl std::fmt::Debug for QueryMatch<'_, '_> {
228 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
229 f.debug_struct("QueryMatch")
230 .field("id", &self.id)
231 .field("pattern", &self.pattern)
232 .field("matched_nodes", &self.matched_nodes)
233 .finish_non_exhaustive()
234 }
235}
236
237impl<'tree> QueryMatch<'_, 'tree> {
238 pub fn matched_nodes(&self) -> impl Iterator<Item = &MatchedNode<'tree>> {
239 self.matched_nodes.iter()
240 }
241
242 pub fn nodes_for_capture(&self, capture: Capture) -> impl Iterator<Item = &Node<'tree>> {
243 self.matched_nodes
244 .iter()
245 .filter(move |mat| mat.capture == capture)
246 .map(|mat| &mat.node)
247 }
248
249 pub fn matched_node(&self, i: MatchedNodeIdx) -> &MatchedNode<'tree> {
250 &self.matched_nodes[i as usize]
251 }
252
253 #[must_use]
254 pub const fn id(&self) -> u32 {
255 self.id
256 }
257
258 #[must_use]
259 pub const fn pattern(&self) -> Pattern {
260 self.pattern
261 }
262
263 #[doc(alias = "ts_query_cursor_remove_match")]
264 pub fn remove(self) {
268 unsafe {
269 ts_query_cursor_remove_match(self.query_cursor, self.id);
270 }
271 }
272}
273
/// Raw capture record as laid out by the C library (`#[repr(C)]`).
///
/// Pointers to this type are cast to `MatchedNode` pointers when matches are
/// read, so the two layouts must stay in sync.
#[repr(C)]
#[derive(Debug)]
struct TSQueryCapture {
    /// The captured node in its raw form.
    node: NodeRaw,
    /// Numeric capture identifier (presumably the capture's index within the query — confirm).
    index: u32,
}
280
/// Raw match record that `ts_query_cursor_next_match` and
/// `ts_query_cursor_next_capture` write into (`#[repr(C)]`).
/// NOTE(review): the field order and widths must agree with the C definition
/// used by the linked tree-sitter library — confirm against its header.
#[repr(C)]
#[derive(Debug)]
struct TSQueryMatch {
    /// Identifier of the match; passed to `ts_query_cursor_remove_match`.
    id: u32,
    /// Index of the pattern that matched.
    pattern_index: u16,
    /// Number of entries behind `captures`.
    capture_count: u16,
    /// Pointer to the first capture; initialised to null before the first call.
    captures: *const TSQueryCapture,
}
289
// Bindings to the query-cursor functions of the C library.
// NOTE(review): the exact semantics are defined by the tree-sitter C API; the
// comments below only describe how this file uses each function.
extern "C" {
    /// Advances to the next capture, writing the owning match into `match_`
    /// and the capture's index into `capture_index`. This file treats a
    /// `false` return as "no more captures".
    fn ts_query_cursor_next_capture(
        self_: *mut QueryCursorData,
        match_: &mut TSQueryMatch,
        capture_index: &mut u32,
    ) -> bool;

    /// Advances to the next match, writing it into `match_`. This file treats
    /// a `false` return as "no more matches".
    fn ts_query_cursor_next_match(self_: *mut QueryCursorData, match_: &mut TSQueryMatch) -> bool;
    /// Removes the match with the given id from the cursor.
    fn ts_query_cursor_remove_match(self_: *mut QueryCursorData, match_id: u32);
    /// Frees a cursor; called when an `InactiveQueryCursor` is dropped.
    fn ts_query_cursor_delete(self_: *mut QueryCursorData);
    /// Allocates a new cursor; called when the thread-local cache is empty.
    fn ts_query_cursor_new() -> *mut QueryCursorData;

    /// Starts executing `query` on `node` with this cursor.
    fn ts_query_cursor_exec(self_: *mut QueryCursorData, query: &QueryData, node: NodeRaw);
    /// Reports whether the cursor exceeded its match limit.
    fn ts_query_cursor_did_exceed_match_limit(self_: *const QueryCursorData) -> bool;
    /// Returns the cursor's current match limit.
    fn ts_query_cursor_match_limit(self_: *const QueryCursorData) -> u32;
    /// Sets the cursor's match limit.
    fn ts_query_cursor_set_match_limit(self_: *mut QueryCursorData, limit: u32);
    /// Sets the byte range (`start_byte`, `end_byte`) on the cursor.
    fn ts_query_cursor_set_byte_range(self_: *mut QueryCursorData, start_byte: u32, end_byte: u32);

}