Skip to main content

tree_sitter_language_pack/
node.rs

1use std::borrow::Cow;
2
3use crate::Error;
4
/// Lightweight, owned snapshot of a tree-sitter node's properties.
///
/// Every field is a plain value (integers, booleans, and a string), which
/// keeps the type easy to serialize and to pass across FFI boundaries. Unlike
/// `tree_sitter::Node`, it does not borrow from the tree.
///
/// The type implements `Hash` consistently with `Eq`, so snapshots can be
/// de-duplicated in a `HashSet` or used as `HashMap` keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct NodeInfo {
    /// The grammar type name (e.g., "function_definition", "identifier").
    ///
    /// Borrowed when it comes straight from the grammar's static name table,
    /// owned when built from a runtime string.
    pub kind: Cow<'static, str>,
    /// Whether this is a named node (as opposed to an anonymous one such as
    /// punctuation).
    pub is_named: bool,
    /// Byte offset in the source at which the node starts.
    pub start_byte: usize,
    /// Byte offset in the source at which the node ends.
    pub end_byte: usize,
    /// Row of the node's start position (zero-indexed).
    pub start_row: usize,
    /// Column of the node's start position (zero-indexed).
    pub start_col: usize,
    /// Row of the node's end position (zero-indexed).
    pub end_row: usize,
    /// Column of the node's end position (zero-indexed).
    pub end_col: usize,
    /// Number of named children the node has.
    pub named_child_count: usize,
    /// Whether this node is an ERROR node.
    pub is_error: bool,
    /// Whether this node is a MISSING node.
    pub is_missing: bool,
}
36
37/// Extract a `NodeInfo` from a tree-sitter `Node`.
38pub fn node_info_from_node(node: tree_sitter::Node) -> NodeInfo {
39    let start = node.start_position();
40    let end = node.end_position();
41    NodeInfo {
42        kind: Cow::Borrowed(node.kind()),
43        is_named: node.is_named(),
44        start_byte: node.start_byte(),
45        end_byte: node.end_byte(),
46        start_row: start.row,
47        start_col: start.column,
48        end_row: end.row,
49        end_col: end.column,
50        named_child_count: node.named_child_count(),
51        is_error: node.is_error(),
52        is_missing: node.is_missing(),
53    }
54}
55
56/// Get a `NodeInfo` snapshot of the root node.
57pub fn root_node_info(tree: &tree_sitter::Tree) -> NodeInfo {
58    node_info_from_node(tree.root_node())
59}
60
61/// Find all nodes matching the given type name, returning their `NodeInfo`.
62///
63/// Performs a depth-first traversal. Returns an empty vec if no matches.
64pub fn find_nodes_by_type(tree: &tree_sitter::Tree, node_type: &str) -> Vec<NodeInfo> {
65    let mut results = Vec::new();
66    let mut cursor = tree.walk();
67    collect_with_cursor(&mut cursor, |node| {
68        if node.kind() == node_type {
69            results.push(node_info_from_node(node));
70        }
71    });
72    results
73}
74
75/// Get `NodeInfo` for all named children of the root node.
76///
77/// Useful for understanding the top-level structure of a file
78/// (e.g., list of function definitions, class declarations, imports).
79pub fn named_children_info(tree: &tree_sitter::Tree) -> Vec<NodeInfo> {
80    let root = tree.root_node();
81    let mut children = Vec::with_capacity(root.named_child_count());
82    let mut cursor = root.walk();
83    if cursor.goto_first_child() {
84        loop {
85            let node = cursor.node();
86            if node.is_named() {
87                children.push(node_info_from_node(node));
88            }
89            if !cursor.goto_next_sibling() {
90                break;
91            }
92        }
93    }
94    children
95}
96
97/// Extract the source text corresponding to a node's byte range.
98///
99/// Returns the slice of source bytes as a UTF-8 string.
100pub fn extract_text<'a>(source: &'a [u8], node_info: &NodeInfo) -> Result<&'a str, Error> {
101    if node_info.end_byte > source.len() {
102        return Err(Error::InvalidRange(format!(
103            "end_byte {} exceeds source length {}",
104            node_info.end_byte,
105            source.len()
106        )));
107    }
108    std::str::from_utf8(&source[node_info.start_byte..node_info.end_byte])
109        .map_err(|e| Error::InvalidRange(format!("not valid UTF-8: {e}")))
110}
111
112/// Visit every node in a depth-first traversal, calling `visitor` on each.
113fn collect_with_cursor(cursor: &mut tree_sitter::TreeCursor, mut visitor: impl FnMut(tree_sitter::Node)) {
114    loop {
115        visitor(cursor.node());
116        if cursor.goto_first_child() {
117            continue;
118        }
119        loop {
120            if cursor.goto_next_sibling() {
121                break;
122            }
123            if !cursor.goto_parent() {
124                return;
125            }
126        }
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `source` with the first available language.
    ///
    /// Returns `None` when no language is available or parsing fails; the
    /// tree-based tests below then return early without asserting anything.
    fn parse_first_lang(source: &[u8]) -> Option<tree_sitter::Tree> {
        let langs = crate::available_languages();
        let first = langs.first()?;
        crate::parse::parse_string(first, source).ok()
    }

    /// Build a synthetic single-row `NodeInfo` covering `start_byte..end_byte`.
    fn info_with_range(start_byte: usize, end_byte: usize) -> NodeInfo {
        NodeInfo {
            kind: Cow::Owned("test".to_string()),
            is_named: true,
            start_byte,
            end_byte,
            start_row: 0,
            start_col: start_byte,
            end_row: 0,
            end_col: end_byte,
            named_child_count: 0,
            is_error: false,
            is_missing: false,
        }
    }

    #[test]
    fn test_root_node_info() {
        let Some(tree) = parse_first_lang(b"x") else {
            return;
        };
        let info = root_node_info(&tree);
        assert!(!info.kind.is_empty());
        assert!(info.is_named);
        assert_eq!(info.start_byte, 0);
    }

    #[test]
    fn test_find_nodes_by_type() {
        let Some(tree) = parse_first_lang(b"x") else {
            return;
        };
        let root_kind = tree.root_node().kind().to_string();
        let nodes = find_nodes_by_type(&tree, &root_kind);
        assert!(!nodes.is_empty());
        assert_eq!(nodes[0].kind, root_kind);
    }

    #[test]
    fn test_find_nodes_by_type_no_match() {
        let Some(tree) = parse_first_lang(b"x") else {
            return;
        };
        let nodes = find_nodes_by_type(&tree, "nonexistent_node_type_xyz");
        assert!(nodes.is_empty());
    }

    #[test]
    fn test_named_children_info() {
        let Some(tree) = parse_first_lang(b"x") else {
            return;
        };
        let children = named_children_info(&tree);
        // How many children the root has depends on the grammar, but whatever
        // is returned must consist solely of named nodes, listed in source order.
        assert!(children.iter().all(|child| child.is_named));
        assert!(children
            .windows(2)
            .all(|pair| pair[0].start_byte <= pair[1].start_byte));
    }

    #[test]
    fn test_extract_text() {
        let source = b"hello world";
        let info = info_with_range(0, 5);
        let text = extract_text(source, &info).unwrap();
        assert_eq!(text, "hello");
    }

    #[test]
    fn test_extract_text_out_of_bounds() {
        let source = b"hi";
        let info = info_with_range(0, 100);
        assert!(extract_text(source, &info).is_err());
    }
}