tree_sitter_language_pack/
query.rs1use std::borrow::Cow;
2use std::cell::RefCell;
3use std::sync::{Arc, LazyLock, RwLock};
4
5use crate::Error;
6use crate::node::{NodeInfo, node_info_from_node};
7use tree_sitter::StreamingIterator;
8
9#[derive(Debug)]
10struct CompiledQuery {
11 query: tree_sitter::Query,
12 capture_names: Vec<Cow<'static, str>>,
13}
14
15type QueryCacheMap = ahash::AHashMap<(String, String), Arc<CompiledQuery>>;
16
17static QUERY_CACHE: LazyLock<RwLock<QueryCacheMap>> = LazyLock::new(|| RwLock::new(QueryCacheMap::new()));
18
19thread_local! {
20 static LOCAL_QUERY_CACHE: RefCell<QueryCacheMap> = RefCell::new(QueryCacheMap::new());
21}
22
23#[derive(Debug, Clone, Default, PartialEq, Eq)]
25#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
26pub struct QueryMatch {
27 pub pattern_index: usize,
29 pub captures: Vec<(Cow<'static, str>, NodeInfo)>,
31}
32
33pub fn run_query(
62 tree: &tree_sitter::Tree,
63 language: &str,
64 query_source: &str,
65 source: &[u8],
66) -> Result<Vec<QueryMatch>, Error> {
67 let query = compiled_query(language, query_source)?;
68
69 let mut cursor = tree_sitter::QueryCursor::new();
70 let mut matches = cursor.matches(&query.query, tree.root_node(), source);
71
72 let mut results = Vec::new();
79 while let Some(m) = matches.next() {
80 let captures = m
81 .captures
82 .iter()
83 .map(|c| {
84 let name = query.capture_names[c.index as usize].clone();
85 let info = node_info_from_node(c.node);
86 (name, info)
87 })
88 .collect();
89 results.push(QueryMatch {
90 pattern_index: m.pattern_index,
91 captures,
92 });
93 }
94 Ok(results)
95}
96
97fn compiled_query(language: &str, query_source: &str) -> Result<Arc<CompiledQuery>, Error> {
98 let key = (language.to_string(), query_source.to_string());
99 if let Some(query) = LOCAL_QUERY_CACHE.with(|cache| cache.borrow().get(&key).cloned()) {
100 return Ok(query);
101 }
102 if let Some(query) = QUERY_CACHE
103 .read()
104 .map_err(|e| Error::LockPoisoned(e.to_string()))?
105 .get(&key)
106 .cloned()
107 {
108 LOCAL_QUERY_CACHE.with(|cache| {
109 cache.borrow_mut().insert(key, Arc::clone(&query));
110 });
111 return Ok(query);
112 }
113
114 let lang = crate::get_language(language)?;
115 let query = tree_sitter::Query::new(&lang, query_source).map_err(|e| Error::QueryError(format!("{e}")))?;
116 let capture_names = query
117 .capture_names()
118 .iter()
119 .map(|s| Cow::Owned(s.to_string()))
120 .collect();
121 let compiled = Arc::new(CompiledQuery { query, capture_names });
122 LOCAL_QUERY_CACHE.with(|cache| {
123 cache.borrow_mut().insert(key.clone(), Arc::clone(&compiled));
124 });
125 let mut global = QUERY_CACHE.write().map_err(|e| Error::LockPoisoned(e.to_string()))?;
126 Ok(global.entry(key).or_insert_with(|| Arc::clone(&compiled)).clone())
127}
128
129#[cfg(test)]
130mod tests {
131 use super::*;
132
133 #[test]
134 fn test_run_query_invalid_language() {
135 let langs = crate::available_languages();
137 if langs.is_empty() {
138 return;
139 }
140 let tree = crate::parse::parse_string(&langs[0], b"x").unwrap();
141 let result = run_query(&tree, "nonexistent_xyz", "(identifier) @id", b"x");
142 assert!(result.is_err());
143 }
144
145 #[test]
146 fn test_run_query_invalid_pattern() {
147 let langs = crate::available_languages();
148 if langs.is_empty() {
149 return;
150 }
151 let first = &langs[0];
152 let tree = crate::parse::parse_string(first, b"x").unwrap();
153 let result = run_query(&tree, first, "((((invalid syntax", b"x");
154 assert!(result.is_err());
155 }
156
157 #[test]
158 fn test_run_query_no_matches() {
159 let langs = crate::available_languages();
160 if langs.is_empty() {
161 return;
162 }
163 let first = &langs[0];
164 let tree = crate::parse::parse_string(first, b"x").unwrap();
165 let result = run_query(&tree, first, "(function_definition) @fn", b"x");
167 if let Ok(matches) = result {
170 assert!(matches.is_empty());
171 }
172 }
174
175 #[test]
176 fn test_compiled_query_reused() {
177 let langs = crate::available_languages();
178 if langs.is_empty() {
179 return;
180 }
181 for lang in &langs {
183 let query_src = "(identifier) @reuse_check";
184 let q1 = match compiled_query(lang, query_src) {
185 Ok(q) => q,
186 Err(_) => continue,
187 };
188 let q2 = compiled_query(lang, query_src).unwrap();
189 assert!(
190 Arc::ptr_eq(&q1, &q2),
191 "repeated compiled_query for '{lang}' should return same Arc"
192 );
193 return;
194 }
195 }
196
197 #[test]
198 fn test_different_languages_same_query_separate_cache() {
199 let langs = crate::available_languages();
200 if langs.len() < 2 {
201 return;
202 }
203 let query_src = "(identifier) @id";
204 let q1 = compiled_query(&langs[0], query_src);
205 let q2 = compiled_query(&langs[1], query_src);
206 if let (Ok(q1), Ok(q2)) = (q1, q2) {
209 assert!(
210 !Arc::ptr_eq(&q1, &q2),
211 "different languages should produce different cached queries"
212 );
213 }
214 }
215
216 #[test]
217 fn test_compiled_query_error_recovery() {
218 let langs = crate::available_languages();
219 if langs.is_empty() {
220 return;
221 }
222 let first = &langs[0];
223 let bad = compiled_query(first, "((((invalid syntax");
225 assert!(bad.is_err());
226 let good = compiled_query(first, "(identifier) @id");
228 let _ = good;
230 }
231
232 #[test]
233 fn test_compiled_query_capture_names_preserved() {
234 let langs = crate::available_languages();
235 if langs.is_empty() {
236 return;
237 }
238 let first = &langs[0];
239 let q = compiled_query(first, "(identifier) @name");
240 if let Ok(q) = q {
241 assert!(!q.capture_names.is_empty(), "capture_names should not be empty");
242 assert_eq!(q.capture_names[0], "name");
243 }
244 }
245
246 #[test]
247 fn test_compiled_query_shared_across_threads() {
248 let langs = crate::available_languages();
249 if langs.is_empty() {
250 return;
251 }
252 let lang = langs[0].clone();
253 let query_src = "(identifier) @id";
254 let q_main = compiled_query(&lang, query_src);
256 if q_main.is_err() {
257 return; }
259 let q_main = q_main.unwrap();
260
261 let lang_clone = lang.clone();
262 let handle = std::thread::spawn(move || compiled_query(&lang_clone, query_src));
263 let q_thread = handle.join().expect("thread should not panic");
264 if let Ok(q_thread) = q_thread {
265 assert!(
266 Arc::ptr_eq(&q_main, &q_thread),
267 "same query from different threads should share the same Arc via global cache"
268 );
269 }
270 }
271}