Skip to main content

tree_sitter_language_pack/
query.rs

1use std::borrow::Cow;
2use std::cell::RefCell;
3use std::sync::{Arc, LazyLock, RwLock};
4
5use crate::Error;
6use crate::node::{NodeInfo, node_info_from_node};
7use tree_sitter::StreamingIterator;
8
9#[derive(Debug)]
10struct CompiledQuery {
11    query: tree_sitter::Query,
12    capture_names: Vec<Cow<'static, str>>,
13}
14
15type QueryCacheMap = ahash::AHashMap<(String, String), Arc<CompiledQuery>>;
16
17static QUERY_CACHE: LazyLock<RwLock<QueryCacheMap>> = LazyLock::new(|| RwLock::new(QueryCacheMap::new()));
18
19thread_local! {
20    static LOCAL_QUERY_CACHE: RefCell<QueryCacheMap> = RefCell::new(QueryCacheMap::new());
21}
22
23/// A single match from a tree-sitter query, with captured nodes.
24#[derive(Debug, Clone, Default, PartialEq, Eq)]
25#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
26pub struct QueryMatch {
27    /// The pattern index that matched (position in the query string).
28    pub pattern_index: usize,
29    /// Captures: list of (capture_name, node_info) pairs.
30    pub captures: Vec<(Cow<'static, str>, NodeInfo)>,
31}
32
33/// Execute a tree-sitter query pattern against a parsed tree.
34///
35/// The `query_source` is an S-expression pattern like:
36/// ```text
37/// (function_definition name: (identifier) @name)
38/// ```
39///
40/// Returns all matches with their captured nodes.
41///
42/// # Arguments
43///
44/// * `tree` - The parsed syntax tree to query.
45/// * `language` - Language name (used to compile the query pattern).
46/// * `query_source` - The tree-sitter query pattern string.
47/// * `source` - The original source code bytes (needed for capture resolution).
48///
49/// # Examples
50///
51/// ```no_run
52/// let tree = tree_sitter_language_pack::parse_string("python", b"def hello(): pass").unwrap();
53/// let matches = tree_sitter_language_pack::run_query(
54///     &tree,
55///     "python",
56///     "(function_definition name: (identifier) @fn_name)",
57///     b"def hello(): pass",
58/// ).unwrap();
59/// assert!(!matches.is_empty());
60/// ```
61pub fn run_query(
62    tree: &tree_sitter::Tree,
63    language: &str,
64    query_source: &str,
65    source: &[u8],
66) -> Result<Vec<QueryMatch>, Error> {
67    let query = compiled_query(language, query_source)?;
68
69    let mut cursor = tree_sitter::QueryCursor::new();
70    let mut matches = cursor.matches(&query.query, tree.root_node(), source);
71
72    // Tree-sitter 0.26+ evaluates standard text predicates (`#eq?`, `#not-eq?`,
73    // `#match?`, `#not-match?`, `#any-of?`, `#not-any-of?`) internally via
74    // `satisfies_text_predicates()` during `QueryCursor::matches()` iteration.
75    // The `general_predicates()` method only returns predicates with operators
76    // that tree-sitter does NOT recognize (i.e., custom predicates). Since we
77    // don't define any custom predicates, no additional filtering is needed.
78    let mut results = Vec::new();
79    while let Some(m) = matches.next() {
80        let captures = m
81            .captures
82            .iter()
83            .map(|c| {
84                let name = query.capture_names[c.index as usize].clone();
85                let info = node_info_from_node(c.node);
86                (name, info)
87            })
88            .collect();
89        results.push(QueryMatch {
90            pattern_index: m.pattern_index,
91            captures,
92        });
93    }
94    Ok(results)
95}
96
97fn compiled_query(language: &str, query_source: &str) -> Result<Arc<CompiledQuery>, Error> {
98    let key = (language.to_string(), query_source.to_string());
99    if let Some(query) = LOCAL_QUERY_CACHE.with(|cache| cache.borrow().get(&key).cloned()) {
100        return Ok(query);
101    }
102    if let Some(query) = QUERY_CACHE
103        .read()
104        .map_err(|e| Error::LockPoisoned(e.to_string()))?
105        .get(&key)
106        .cloned()
107    {
108        LOCAL_QUERY_CACHE.with(|cache| {
109            cache.borrow_mut().insert(key, Arc::clone(&query));
110        });
111        return Ok(query);
112    }
113
114    let lang = crate::get_language(language)?;
115    let query = tree_sitter::Query::new(&lang, query_source).map_err(|e| Error::QueryError(format!("{e}")))?;
116    let capture_names = query
117        .capture_names()
118        .iter()
119        .map(|s| Cow::Owned(s.to_string()))
120        .collect();
121    let compiled = Arc::new(CompiledQuery { query, capture_names });
122    LOCAL_QUERY_CACHE.with(|cache| {
123        cache.borrow_mut().insert(key.clone(), Arc::clone(&compiled));
124    });
125    let mut global = QUERY_CACHE.write().map_err(|e| Error::LockPoisoned(e.to_string()))?;
126    Ok(global.entry(key).or_insert_with(|| Arc::clone(&compiled)).clone())
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_run_query_invalid_language() {
        // Create a dummy tree from any available language
        let langs = crate::available_languages();
        if langs.is_empty() {
            return;
        }
        let tree = crate::parse::parse_string(&langs[0], b"x").unwrap();
        let result = run_query(&tree, "nonexistent_xyz", "(identifier) @id", b"x");
        assert!(result.is_err());
    }

    #[test]
    fn test_run_query_invalid_pattern() {
        let langs = crate::available_languages();
        if langs.is_empty() {
            return;
        }
        let first = &langs[0];
        let tree = crate::parse::parse_string(first, b"x").unwrap();
        let result = run_query(&tree, first, "((((invalid syntax", b"x");
        assert!(result.is_err());
    }

    #[test]
    fn test_run_query_no_matches() {
        let langs = crate::available_languages();
        if langs.is_empty() {
            return;
        }
        let first = &langs[0];
        let tree = crate::parse::parse_string(first, b"x").unwrap();
        // Query for a node type that is unlikely to exist for a single "x"
        let result = run_query(&tree, first, "(function_definition) @fn", b"x");
        // This might error if the grammar doesn't have function_definition,
        // or return empty matches. Either is acceptable.
        if let Ok(matches) = result {
            assert!(matches.is_empty());
        }
        // Query compilation error is fine for some grammars
    }

    #[test]
    fn test_run_query_returns_named_captures() {
        let src = b"def hello(): pass";
        let Ok(tree) = crate::parse::parse_string("python", src) else {
            return; // python grammar not available in this build
        };
        let matches = run_query(
            &tree,
            "python",
            "(function_definition name: (identifier) @fn_name)",
            src,
        )
        .unwrap();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].pattern_index, 0);
        assert_eq!(matches[0].captures.len(), 1);
        assert_eq!(matches[0].captures[0].0, "fn_name");
    }

    #[test]
    fn test_compiled_query_reused() {
        let langs = crate::available_languages();
        if langs.is_empty() {
            return;
        }
        // Try each language until we find one where the query compiles.
        for lang in &langs {
            let query_src = "(identifier) @reuse_check";
            let q1 = match compiled_query(lang, query_src) {
                Ok(q) => q,
                Err(_) => continue,
            };
            let q2 = compiled_query(lang, query_src).unwrap();
            assert!(
                Arc::ptr_eq(&q1, &q2),
                "repeated compiled_query for '{lang}' should return same Arc"
            );
            return;
        }
    }

    #[test]
    fn test_different_languages_same_query_separate_cache() {
        let langs = crate::available_languages();
        if langs.len() < 2 {
            return;
        }
        let query_src = "(identifier) @id";
        let q1 = compiled_query(&langs[0], query_src);
        let q2 = compiled_query(&langs[1], query_src);
        // Both might fail if grammar doesn't have identifiers, but if both succeed
        // they should be different Arc pointers.
        if let (Ok(q1), Ok(q2)) = (q1, q2) {
            assert!(
                !Arc::ptr_eq(&q1, &q2),
                "different languages should produce different cached queries"
            );
        }
    }

    #[test]
    fn test_compiled_query_error_recovery() {
        let bad_src = "((((invalid syntax";
        let good_src = "(identifier) @id";
        // Pick a grammar in which the valid query compiles, so the recovery check
        // below actually asserts something.
        let Some(lang) = crate::available_languages()
            .into_iter()
            .find(|lang| compiled_query(lang, good_src).is_ok())
        else {
            return;
        };
        // Invalid query should fail, and the failure must not be cached as a success.
        assert!(compiled_query(&lang, bad_src).is_err());
        assert!(compiled_query(&lang, bad_src).is_err());
        // Valid query should still work after a failed compilation.
        assert!(compiled_query(&lang, good_src).is_ok());
    }

    #[test]
    fn test_compiled_query_capture_names_preserved() {
        // Use the first grammar in which the query compiles.
        for lang in crate::available_languages() {
            if let Ok(q) = compiled_query(&lang, "(identifier) @name") {
                assert_eq!(q.capture_names.len(), 1, "exactly one capture expected");
                assert_eq!(q.capture_names[0], "name");
                return;
            }
        }
    }

    #[test]
    fn test_compiled_query_shared_across_threads() {
        let langs = crate::available_languages();
        if langs.is_empty() {
            return;
        }
        let lang = langs[0].clone();
        let query_src = "(identifier) @id";
        // Prime the global cache from this thread
        let q_main = compiled_query(&lang, query_src);
        if q_main.is_err() {
            return; // Grammar doesn't support this query
        }
        let q_main = q_main.unwrap();

        let lang_clone = lang.clone();
        let handle = std::thread::spawn(move || compiled_query(&lang_clone, query_src));
        let q_thread = handle.join().expect("thread should not panic");
        if let Ok(q_thread) = q_thread {
            assert!(
                Arc::ptr_eq(&q_main, &q_thread),
                "same query from different threads should share the same Arc via global cache"
            );
        }
    }
}