1use crate::error::{Result, SpliceError};
7use ropey::Rope;
8use std::path::Path;
9
10#[derive(Debug, Clone, PartialEq)]
12pub struct JavaSymbol {
13 pub name: String,
15
16 pub kind: JavaSymbolKind,
18
19 pub byte_start: usize,
21
22 pub byte_end: usize,
24
25 pub line_start: usize,
27
28 pub line_end: usize,
30
31 pub col_start: usize,
33
34 pub col_end: usize,
36
37 pub parameters: Vec<String>,
39
40 pub container_path: String,
42
43 pub fully_qualified: String,
45
46 pub is_public: bool,
48
49 pub is_static: bool,
51}
52
/// The kind of Java declaration a [`JavaSymbol`] represents.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JavaSymbolKind {
    /// `class Foo { ... }`
    Class,
    /// `interface Foo { ... }`
    Interface,
    /// `enum Foo { ... }`
    Enum,
    /// A method declaration.
    Method,
    /// A constructor declaration.
    Constructor,
    /// A field declaration.
    Field,
}

impl JavaSymbolKind {
    /// Returns the lowercase, stable string name of this kind (e.g. `"class"`).
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Class => "class",
            Self::Interface => "interface",
            Self::Enum => "enum",
            Self::Method => "method",
            Self::Constructor => "constructor",
            Self::Field => "field",
        }
    }
}
83
84pub fn extract_java_symbols(path: &Path, source: &[u8]) -> Result<Vec<JavaSymbol>> {
96 let mut parser = tree_sitter::Parser::new();
97 parser
98 .set_language(&tree_sitter_java::language())
99 .map_err(|e| SpliceError::Parse {
100 file: path.to_path_buf(),
101 message: format!("Failed to set Java language: {:?}", e),
102 })?;
103
104 let tree = parser
105 .parse(source, None)
106 .ok_or_else(|| SpliceError::Parse {
107 file: path.to_path_buf(),
108 message: "Parse failed - no tree returned".to_string(),
109 })?;
110
111 let rope = Rope::from_str(std::str::from_utf8(source)?);
112
113 let mut symbols = Vec::new();
114 extract_symbols(tree.root_node(), source, &rope, &mut symbols, "");
115
116 Ok(symbols)
117}
118
119fn extract_symbols(
121 node: tree_sitter::Node,
122 source: &[u8],
123 rope: &Rope,
124 symbols: &mut Vec<JavaSymbol>,
125 container_path: &str,
126) {
127 let kind = node.kind();
128
129 let is_public = has_modifier(node, "public");
131 let is_static = has_modifier(node, "static");
132
133 let symbol_kind = match kind {
135 "class_declaration" => Some(JavaSymbolKind::Class),
136 "interface_declaration" => Some(JavaSymbolKind::Interface),
137 "enum_declaration" => Some(JavaSymbolKind::Enum),
138 "method_declaration" => Some(JavaSymbolKind::Method),
139 "constructor_declaration" => Some(JavaSymbolKind::Constructor),
140 "field_declaration" => Some(JavaSymbolKind::Field),
141 _ => None,
142 };
143
144 if let Some(kind) = symbol_kind {
145 if let Some(symbol) = extract_symbol(
146 node,
147 source,
148 rope,
149 kind,
150 container_path,
151 is_public,
152 is_static,
153 ) {
154 let name = symbol.name.clone();
155
156 symbols.push(symbol);
157
158 if matches!(
160 kind,
161 JavaSymbolKind::Class | JavaSymbolKind::Interface | JavaSymbolKind::Enum
162 ) {
163 let new_container = if container_path.is_empty() {
164 name.clone()
165 } else {
166 format!("{}.{}", container_path, name)
167 };
168
169 if let Some(body) = node.child_by_field_name("body") {
171 extract_symbols(body, source, rope, symbols, &new_container);
172 }
173
174 return;
175 }
176 }
177 }
178
179 let mut cursor = node.walk();
181 for child in node.children(&mut cursor) {
182 if matches!(
184 kind,
185 "class_declaration" | "interface_declaration" | "enum_declaration"
186 ) && matches!(child.kind(), "class_body" | "interface_body" | "enum_body")
187 {
188 continue;
189 }
190 if kind == "field_declaration" && child.kind() == "variable_declarator" {
192 continue;
193 }
194 extract_symbols(child, source, rope, symbols, container_path);
195 }
196}
197
198fn has_modifier(node: tree_sitter::Node, modifier: &str) -> bool {
200 for child in node.children(&mut node.walk()) {
202 if child.kind() == "modifiers" {
203 for modifier_node in child.children(&mut child.walk()) {
204 if modifier_node.kind() == modifier {
205 return true;
206 }
207 }
208 }
209 }
210 false
211}
212
213fn extract_symbol(
215 node: tree_sitter::Node,
216 source: &[u8],
217 rope: &Rope,
218 kind: JavaSymbolKind,
219 container_path: &str,
220 is_public: bool,
221 is_static: bool,
222) -> Option<JavaSymbol> {
223 let name = extract_name(node, source)?;
224
225 let byte_start = node.start_byte();
226 let byte_end = node.end_byte();
227
228 let start_char = rope.byte_to_char(byte_start);
229 let end_char = rope.byte_to_char(byte_end);
230
231 let line_start = rope.char_to_line(start_char);
232 let line_end = rope.char_to_line(end_char);
233
234 let line_start_byte = rope.line_to_byte(line_start);
235 let line_end_byte = rope.line_to_byte(line_end);
236
237 let col_start = byte_start - line_start_byte;
238 let col_end = byte_end - line_end_byte;
239
240 let parameters = extract_parameters(node, source);
241
242 let fully_qualified = if container_path.is_empty() {
243 name.clone()
244 } else {
245 format!("{}.{}", container_path, name)
246 };
247
248 Some(JavaSymbol {
249 name,
250 kind,
251 byte_start,
252 byte_end,
253 line_start: line_start + 1,
254 line_end: line_end + 1,
255 col_start,
256 col_end,
257 parameters,
258 container_path: container_path.to_string(),
259 fully_qualified,
260 is_public,
261 is_static,
262 })
263}
264
265fn extract_name(node: tree_sitter::Node, source: &[u8]) -> Option<String> {
267 let kind = node.kind();
268
269 match kind {
270 "class_declaration" | "interface_declaration" | "enum_declaration" => node
271 .child_by_field_name("name")
272 .and_then(|n| n.utf8_text(source).ok().map(|s| s.to_string())),
273 "method_declaration" | "constructor_declaration" => node
274 .child_by_field_name("name")
275 .and_then(|n| n.utf8_text(source).ok().map(|s| s.to_string())),
276 "field_declaration" => {
277 for child in node.children(&mut node.walk()) {
279 if child.kind() == "variable_declarator" {
280 if let Some(name_node) = child.child_by_field_name("name") {
281 if let Ok(name) = name_node.utf8_text(source) {
282 return Some(name.to_string());
283 }
284 }
285 }
286 }
287 None
288 }
289 _ => None,
290 }
291}
292
293fn extract_parameters(node: tree_sitter::Node, source: &[u8]) -> Vec<String> {
295 let mut parameters = Vec::new();
296
297 if let Some(params) = node.child_by_field_name("parameters") {
298 for param in params.children(&mut params.walk()) {
299 if param.kind() == "formal_parameter" {
300 if let Some(name_node) = param.child_by_field_name("name") {
301 if let Ok(name) = name_node.utf8_text(source) {
302 parameters.push(name.to_string());
303 }
304 }
305 }
306 }
307 }
308
309 parameters
310}
311
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` and panics with the underlying error if extraction fails.
    fn parse_symbols(source: &[u8]) -> Vec<JavaSymbol> {
        extract_java_symbols(Path::new("test.java"), source)
            .expect("Java source should parse and yield symbols")
    }

    #[test]
    fn test_extract_simple_class() {
        let symbols = parse_symbols(b"class MyClass {}\n");
        assert_eq!(symbols.len(), 1);
        assert_eq!(symbols[0].name, "MyClass");
        assert_eq!(symbols[0].kind.as_str(), "class");
    }

    #[test]
    fn test_extract_class_with_method() {
        let symbols = parse_symbols(b"class MyClass { void method() {} }\n");
        assert_eq!(symbols.len(), 2);
        assert_eq!(symbols[0].name, "MyClass");
        assert_eq!(symbols[0].kind.as_str(), "class");
        assert_eq!(symbols[1].name, "method");
        assert_eq!(symbols[1].kind.as_str(), "method");
    }

    #[test]
    fn test_extract_class_with_field() {
        let symbols = parse_symbols(b"class MyClass { private int field; }\n");
        assert_eq!(symbols.len(), 2);
        assert_eq!(symbols[0].name, "MyClass");
        assert_eq!(symbols[1].name, "field");
        assert_eq!(symbols[1].kind.as_str(), "field");
    }

    #[test]
    fn test_extract_interface() {
        let symbols = parse_symbols(b"interface MyInterface { void method(); }\n");
        assert_eq!(symbols.len(), 2);
        assert_eq!(symbols[0].name, "MyInterface");
        assert_eq!(symbols[0].kind.as_str(), "interface");
        assert_eq!(symbols[1].name, "method");
        assert_eq!(symbols[1].kind.as_str(), "method");
    }

    #[test]
    fn test_extract_enum() {
        let symbols = parse_symbols(b"enum Color { RED, GREEN, BLUE }\n");
        assert_eq!(symbols.len(), 1);
        assert_eq!(symbols[0].name, "Color");
        assert_eq!(symbols[0].kind.as_str(), "enum");
    }

    #[test]
    fn test_extract_class_with_constructor() {
        let symbols = parse_symbols(b"class Foo { Foo() {} }\n");
        assert_eq!(symbols.len(), 2);
        assert_eq!(symbols[0].name, "Foo");
        assert_eq!(symbols[0].kind.as_str(), "class");
        assert_eq!(symbols[1].name, "Foo");
        assert_eq!(symbols[1].kind.as_str(), "constructor");
    }

    #[test]
    fn test_extract_method_with_parameters() {
        let symbols = parse_symbols(b"class MyClass { void add(int a, int b) {} }\n");
        assert_eq!(symbols.len(), 2);
        assert_eq!(symbols[1].parameters, vec!["a", "b"]);
    }

    #[test]
    fn test_extract_public_class() {
        let symbols = parse_symbols(b"public class MyClass {}\n");
        assert_eq!(symbols.len(), 1);
        assert_eq!(symbols[0].name, "MyClass");
        assert!(symbols[0].is_public);
    }

    #[test]
    fn test_extract_static_method() {
        let symbols = parse_symbols(b"class MyClass { static void method() {} }\n");
        assert_eq!(symbols.len(), 2);
        assert_eq!(symbols[1].name, "method");
        assert!(symbols[1].is_static);
    }

    #[test]
    fn test_nested_container_paths() {
        let symbols = parse_symbols(b"class Outer { class Inner { void m() {} } }\n");
        assert_eq!(symbols.len(), 3);

        assert_eq!(symbols[0].container_path, "");
        assert_eq!(symbols[0].fully_qualified, "Outer");

        assert_eq!(symbols[1].name, "Inner");
        assert_eq!(symbols[1].container_path, "Outer");
        assert_eq!(symbols[1].fully_qualified, "Outer.Inner");

        assert_eq!(symbols[2].name, "m");
        assert_eq!(symbols[2].container_path, "Outer.Inner");
        assert_eq!(symbols[2].fully_qualified, "Outer.Inner.m");
    }

    #[test]
    fn test_symbol_positions() {
        let symbols = parse_symbols(b"class A {\n    void m() {}\n}\n");
        assert_eq!(symbols.len(), 2);

        let class = &symbols[0];
        assert_eq!((class.byte_start, class.byte_end), (0, 27));
        assert_eq!((class.line_start, class.line_end), (1, 3));
        assert_eq!((class.col_start, class.col_end), (0, 1));

        let method = &symbols[1];
        assert_eq!((method.byte_start, method.byte_end), (14, 25));
        assert_eq!((method.line_start, method.line_end), (2, 2));
        assert_eq!((method.col_start, method.col_end), (4, 15));
    }

    #[test]
    fn test_field_modifiers() {
        let symbols = parse_symbols(b"class A { public static final int X = 1; }\n");
        assert_eq!(symbols.len(), 2);
        assert!(!symbols[0].is_public);
        assert_eq!(symbols[1].name, "X");
        assert_eq!(symbols[1].kind, JavaSymbolKind::Field);
        assert!(symbols[1].is_public);
        assert!(symbols[1].is_static);
    }
}