Skip to main content

splice/ingest/
java.rs

1//! Java-specific tree-sitter parsing logic.
2//!
3//! This module contains tree-sitter-java integration for extracting
4//! classes, interfaces, enums, methods, constructors, fields, and other Java constructs with byte spans.
5
6use crate::error::{Result, SpliceError};
7use ropey::Rope;
8use std::path::Path;
9
/// A single Java declaration found in a source file, together with its
/// location expressed both as byte offsets and as line/column positions.
#[derive(Debug, Clone, PartialEq)]
pub struct JavaSymbol {
    /// Local symbol name (e.g., `MyClass`).
    ///
    /// For a field declaration this is the name of the first declarator
    /// that yields a readable name.
    pub name: String,

    /// Symbol kind (class, interface, enum, method, constructor, field).
    pub kind: JavaSymbolKind,

    /// Byte offset at which the declaration node starts.
    pub byte_start: usize,

    /// Byte offset at which the declaration node ends.
    pub byte_end: usize,

    /// Line containing `byte_start` (1-based).
    pub line_start: usize,

    /// Line containing `byte_end` (1-based).
    pub line_end: usize,

    /// Distance in bytes from the start of the first line to `byte_start` (0-based).
    pub col_start: usize,

    /// Distance in bytes from the start of the last line to `byte_end` (0-based).
    pub col_end: usize,

    /// Parameter names (names only, no types) of a method or constructor.
    /// Empty when the declaration has no parameter list.
    pub parameters: Vec<String>,

    /// Dot-separated names of the enclosing class/interface/enum declarations
    /// (e.g., `Outer.Inner`). Empty for top-level declarations.
    pub container_path: String,

    /// `container_path` and `name` joined with a dot (e.g., `Outer.Inner.method`),
    /// or just `name` when `container_path` is empty.
    pub fully_qualified: String,

    /// Whether the declaration's modifier list contains `public`.
    pub is_public: bool,

    /// Whether the declaration's modifier list contains `static`.
    pub is_static: bool,
}
52
/// The categories of Java declarations this module records.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JavaSymbolKind {
    /// A `class` declaration.
    Class,
    /// An `interface` declaration.
    Interface,
    /// An `enum` declaration.
    Enum,
    /// A method declaration.
    Method,
    /// A constructor declaration.
    Constructor,
    /// A field declaration.
    Field,
}

impl JavaSymbolKind {
    /// Lower-case label for this kind, suitable for persisting alongside a symbol.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Class => "class",
            Self::Interface => "interface",
            Self::Enum => "enum",
            Self::Method => "method",
            Self::Constructor => "constructor",
            Self::Field => "field",
        }
    }
}
83
84/// Extract symbols and spans from a Java source file.
85///
86/// Uses tree-sitter-java to parse the file and extract:
87/// - Class declarations
88/// - Interface declarations
89/// - Enum declarations
90/// - Method declarations
91/// - Constructor declarations
92/// - Field declarations
93///
94/// Returns a list of symbol entries ready for graph insertion.
95pub fn extract_java_symbols(path: &Path, source: &[u8]) -> Result<Vec<JavaSymbol>> {
96    let mut parser = tree_sitter::Parser::new();
97    parser
98        .set_language(&tree_sitter_java::language())
99        .map_err(|e| SpliceError::Parse {
100            file: path.to_path_buf(),
101            message: format!("Failed to set Java language: {:?}", e),
102        })?;
103
104    let tree = parser
105        .parse(source, None)
106        .ok_or_else(|| SpliceError::Parse {
107            file: path.to_path_buf(),
108            message: "Parse failed - no tree returned".to_string(),
109        })?;
110
111    let rope = Rope::from_str(std::str::from_utf8(source)?);
112
113    let mut symbols = Vec::new();
114    extract_symbols(tree.root_node(), source, &rope, &mut symbols, "");
115
116    Ok(symbols)
117}
118
119/// Extract symbols from AST nodes.
120fn extract_symbols(
121    node: tree_sitter::Node,
122    source: &[u8],
123    rope: &Rope,
124    symbols: &mut Vec<JavaSymbol>,
125    container_path: &str,
126) {
127    let kind = node.kind();
128
129    // Check for modifiers
130    let is_public = has_modifier(node, "public");
131    let is_static = has_modifier(node, "static");
132
133    // Determine symbol kind
134    let symbol_kind = match kind {
135        "class_declaration" => Some(JavaSymbolKind::Class),
136        "interface_declaration" => Some(JavaSymbolKind::Interface),
137        "enum_declaration" => Some(JavaSymbolKind::Enum),
138        "method_declaration" => Some(JavaSymbolKind::Method),
139        "constructor_declaration" => Some(JavaSymbolKind::Constructor),
140        "field_declaration" => Some(JavaSymbolKind::Field),
141        _ => None,
142    };
143
144    if let Some(kind) = symbol_kind {
145        if let Some(symbol) = extract_symbol(
146            node,
147            source,
148            rope,
149            kind,
150            container_path,
151            is_public,
152            is_static,
153        ) {
154            let name = symbol.name.clone();
155
156            symbols.push(symbol);
157
158            // For classes, interfaces, and enums, extract nested symbols
159            if matches!(
160                kind,
161                JavaSymbolKind::Class | JavaSymbolKind::Interface | JavaSymbolKind::Enum
162            ) {
163                let new_container = if container_path.is_empty() {
164                    name.clone()
165                } else {
166                    format!("{}.{}", container_path, name)
167                };
168
169                // Extract from class/interface/enum body
170                if let Some(body) = node.child_by_field_name("body") {
171                    extract_symbols(body, source, rope, symbols, &new_container);
172                }
173
174                return;
175            }
176        }
177    }
178
179    // Recursively process children (unless we already handled class/interface/enum bodies)
180    let mut cursor = node.walk();
181    for child in node.children(&mut cursor) {
182        // Skip bodies of classes/interfaces/enums as we handle them above
183        if matches!(
184            kind,
185            "class_declaration" | "interface_declaration" | "enum_declaration"
186        ) && matches!(child.kind(), "class_body" | "interface_body" | "enum_body")
187        {
188            continue;
189        }
190        // Skip declarator children of field_declaration (already handled in extract_name)
191        if kind == "field_declaration" && child.kind() == "variable_declarator" {
192            continue;
193        }
194        extract_symbols(child, source, rope, symbols, container_path);
195    }
196}
197
198/// Check if a node has a specific modifier (public, private, static, etc.).
199fn has_modifier(node: tree_sitter::Node, modifier: &str) -> bool {
200    // Check for modifiers child
201    for child in node.children(&mut node.walk()) {
202        if child.kind() == "modifiers" {
203            for modifier_node in child.children(&mut child.walk()) {
204                if modifier_node.kind() == modifier {
205                    return true;
206                }
207            }
208        }
209    }
210    false
211}
212
213/// Extract a single symbol from a tree-sitter node.
214fn extract_symbol(
215    node: tree_sitter::Node,
216    source: &[u8],
217    rope: &Rope,
218    kind: JavaSymbolKind,
219    container_path: &str,
220    is_public: bool,
221    is_static: bool,
222) -> Option<JavaSymbol> {
223    let name = extract_name(node, source)?;
224
225    let byte_start = node.start_byte();
226    let byte_end = node.end_byte();
227
228    let start_char = rope.byte_to_char(byte_start);
229    let end_char = rope.byte_to_char(byte_end);
230
231    let line_start = rope.char_to_line(start_char);
232    let line_end = rope.char_to_line(end_char);
233
234    let line_start_byte = rope.line_to_byte(line_start);
235    let line_end_byte = rope.line_to_byte(line_end);
236
237    let col_start = byte_start - line_start_byte;
238    let col_end = byte_end - line_end_byte;
239
240    let parameters = extract_parameters(node, source);
241
242    let fully_qualified = if container_path.is_empty() {
243        name.clone()
244    } else {
245        format!("{}.{}", container_path, name)
246    };
247
248    Some(JavaSymbol {
249        name,
250        kind,
251        byte_start,
252        byte_end,
253        line_start: line_start + 1,
254        line_end: line_end + 1,
255        col_start,
256        col_end,
257        parameters,
258        container_path: container_path.to_string(),
259        fully_qualified,
260        is_public,
261        is_static,
262    })
263}
264
265/// Extract the name from a node.
266fn extract_name(node: tree_sitter::Node, source: &[u8]) -> Option<String> {
267    let kind = node.kind();
268
269    match kind {
270        "class_declaration" | "interface_declaration" | "enum_declaration" => node
271            .child_by_field_name("name")
272            .and_then(|n| n.utf8_text(source).ok().map(|s| s.to_string())),
273        "method_declaration" | "constructor_declaration" => node
274            .child_by_field_name("name")
275            .and_then(|n| n.utf8_text(source).ok().map(|s| s.to_string())),
276        "field_declaration" => {
277            // For field declarations, get the name from the first declarator
278            for child in node.children(&mut node.walk()) {
279                if child.kind() == "variable_declarator" {
280                    if let Some(name_node) = child.child_by_field_name("name") {
281                        if let Ok(name) = name_node.utf8_text(source) {
282                            return Some(name.to_string());
283                        }
284                    }
285                }
286            }
287            None
288        }
289        _ => None,
290    }
291}
292
293/// Extract parameter names from a method/constructor.
294fn extract_parameters(node: tree_sitter::Node, source: &[u8]) -> Vec<String> {
295    let mut parameters = Vec::new();
296
297    if let Some(params) = node.child_by_field_name("parameters") {
298        for param in params.children(&mut params.walk()) {
299            if param.kind() == "formal_parameter" {
300                if let Some(name_node) = param.child_by_field_name("name") {
301                    if let Ok(name) = name_node.utf8_text(source) {
302                        parameters.push(name.to_string());
303                    }
304                }
305            }
306        }
307    }
308
309    parameters
310}
311
#[cfg(test)]
mod tests {
    use super::*;

    /// Run the extractor on `source` (labelled `test.java`) and return its symbols,
    /// failing the calling test if extraction reports an error.
    fn symbols_in(source: &[u8]) -> Vec<JavaSymbol> {
        let result = extract_java_symbols(Path::new("test.java"), source);
        assert!(result.is_ok());
        result.unwrap()
    }

    /// Assert both the name and the kind label of one symbol.
    fn assert_named_kind(symbol: &JavaSymbol, name: &str, kind: &str) {
        assert_eq!(symbol.name, name);
        assert_eq!(symbol.kind.as_str(), kind);
    }

    #[test]
    fn test_extract_simple_class() {
        let symbols = symbols_in(b"class MyClass {}\n");
        assert_eq!(symbols.len(), 1);
        assert_named_kind(&symbols[0], "MyClass", "class");
    }

    #[test]
    fn test_extract_class_with_method() {
        let symbols = symbols_in(b"class MyClass { void method() {} }\n");
        // One symbol for the class, one for its method.
        assert_eq!(symbols.len(), 2);
        assert_named_kind(&symbols[0], "MyClass", "class");
        assert_named_kind(&symbols[1], "method", "method");
    }

    #[test]
    fn test_extract_class_with_field() {
        let symbols = symbols_in(b"class MyClass { private int field; }\n");
        assert_eq!(symbols.len(), 2);
        assert_eq!(symbols[0].name, "MyClass");
        assert_named_kind(&symbols[1], "field", "field");
    }

    #[test]
    fn test_extract_interface() {
        let symbols = symbols_in(b"interface MyInterface { void method(); }\n");
        assert_eq!(symbols.len(), 2);
        assert_named_kind(&symbols[0], "MyInterface", "interface");
        assert_named_kind(&symbols[1], "method", "method");
    }

    #[test]
    fn test_extract_enum() {
        let symbols = symbols_in(b"enum Color { RED, GREEN, BLUE }\n");
        assert_eq!(symbols.len(), 1);
        assert_named_kind(&symbols[0], "Color", "enum");
    }

    #[test]
    fn test_extract_class_with_constructor() {
        let symbols = symbols_in(b"class Foo { Foo() {} }\n");
        assert_eq!(symbols.len(), 2);
        assert_named_kind(&symbols[0], "Foo", "class");
        assert_named_kind(&symbols[1], "Foo", "constructor");
    }

    #[test]
    fn test_extract_method_with_parameters() {
        let symbols = symbols_in(b"class MyClass { void add(int a, int b) {} }\n");
        assert_eq!(symbols.len(), 2);
        assert_eq!(symbols[1].parameters, vec!["a", "b"]);
    }

    #[test]
    fn test_extract_public_class() {
        let symbols = symbols_in(b"public class MyClass {}\n");
        assert_eq!(symbols.len(), 1);
        assert_eq!(symbols[0].name, "MyClass");
        assert!(symbols[0].is_public);
    }

    #[test]
    fn test_extract_static_method() {
        let symbols = symbols_in(b"class MyClass { static void method() {} }\n");
        assert_eq!(symbols.len(), 2);
        assert_eq!(symbols[1].name, "method");
        assert!(symbols[1].is_static);
    }
}